forked from M-Labs/nac3
1
0
Fork 0

core: Refactor range function into constructor

This commit is contained in:
David Mak 2024-07-08 14:22:19 +08:00 committed by sb10q
parent 9238a5e86e
commit 2cfb7a7e10
2 changed files with 164 additions and 134 deletions

View File

@ -14,10 +14,7 @@ use strum::IntoEnumIterator;
use crate::{ use crate::{
codegen::{ codegen::{
builtin_fns, builtin_fns,
classes::{ classes::{ArrayLikeValue, NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor},
ArrayLikeValue, NDArrayValue, ProxyType, ProxyValue, RangeType, RangeValue,
TypedArrayLikeAccessor,
},
expr::destructure_range, expr::destructure_range,
irrt::*, irrt::*,
numpy::*, numpy::*,
@ -460,9 +457,10 @@ impl<'a> BuiltinBuilder<'a> {
| PrimDef::Float | PrimDef::Float
| PrimDef::Bool | PrimDef::Bool
| PrimDef::Str | PrimDef::Str
| PrimDef::Range
| PrimDef::None => Self::build_simple_primitive_class(prim), | PrimDef::None => Self::build_simple_primitive_class(prim),
PrimDef::Range | PrimDef::FunRangeInit => self.build_range_class_related(prim),
PrimDef::Exception => self.build_exception_class_related(prim), PrimDef::Exception => self.build_exception_class_related(prim),
PrimDef::Option PrimDef::Option
@ -494,7 +492,6 @@ impl<'a> BuiltinBuilder<'a> {
| PrimDef::FunNpEye | PrimDef::FunNpEye
| PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim), | PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim),
PrimDef::FunRange => self.build_range_function(),
PrimDef::FunStr => self.build_str_function(), PrimDef::FunStr => self.build_str_function(),
PrimDef::FunFloor | PrimDef::FunFloor64 | PrimDef::FunCeil | PrimDef::FunCeil64 => { PrimDef::FunFloor | PrimDef::FunFloor64 | PrimDef::FunCeil | PrimDef::FunCeil64 => {
@ -599,7 +596,6 @@ impl<'a> BuiltinBuilder<'a> {
PrimDef::Float, PrimDef::Float,
PrimDef::Bool, PrimDef::Bool,
PrimDef::Str, PrimDef::Str,
PrimDef::Range,
PrimDef::None, PrimDef::None,
], ],
); );
@ -607,6 +603,165 @@ impl<'a> BuiltinBuilder<'a> {
TopLevelComposer::make_top_level_class_def(prim.id(), None, prim.name().into(), None, None) TopLevelComposer::make_top_level_class_def(prim.id(), None, prim.name().into(), None, None)
} }
fn build_range_class_related(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed(prim, &[PrimDef::Range, PrimDef::FunRangeInit]);
let PrimitiveStore { int32, range, .. } = *self.primitives;
let make_ctor_signature = |unifier: &mut Unifier| {
unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg { name: "start".into(), ty: int32, default_value: None },
FuncArg {
name: "stop".into(),
ty: int32,
// placeholder
default_value: Some(SymbolValue::I32(0)),
},
FuncArg {
name: "step".into(),
ty: int32,
default_value: Some(SymbolValue::I32(1)),
},
],
ret: range,
vars: VarMap::default(),
}))
};
match prim {
PrimDef::Range => {
let fields = vec![
("start".into(), int32, true),
("stop".into(), int32, true),
("step".into(), int32, true),
];
let ctor_signature = make_ctor_signature(self.unifier);
TopLevelDef::Class {
name: prim.name().into(),
object_id: prim.id(),
type_vars: Vec::default(),
fields,
attributes: Vec::default(),
methods: vec![("__init__".into(), ctor_signature, PrimDef::FunRangeInit.id())],
ancestors: Vec::default(),
constructor: Some(ctor_signature),
resolver: None,
loc: None,
}
}
PrimDef::FunRangeInit => TopLevelDef::Function {
name: prim.name().into(),
simple_name: prim.simple_name().into(),
signature: make_ctor_signature(self.unifier),
var_id: Vec::default(),
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, obj, _, args, generator| {
let (zelf_ty, zelf) = obj.unwrap();
let zelf =
zelf.to_basic_value_enum(ctx, generator, zelf_ty)?.into_pointer_value();
let zelf = RangeValue::from_ptr_val(zelf, Some("range"));
let mut start = None;
let mut stop = None;
let mut step = None;
let int32 = ctx.ctx.i32_type();
let ty_i32 = ctx.primitives.int32;
for (i, arg) in args.iter().enumerate() {
if arg.0 == Some("start".into()) {
start = Some(
arg.1
.clone()
.to_basic_value_enum(ctx, generator, ty_i32)?
.into_int_value(),
);
} else if arg.0 == Some("stop".into()) {
stop = Some(
arg.1
.clone()
.to_basic_value_enum(ctx, generator, ty_i32)?
.into_int_value(),
);
} else if arg.0 == Some("step".into()) {
step = Some(
arg.1
.clone()
.to_basic_value_enum(ctx, generator, ty_i32)?
.into_int_value(),
);
} else if i == 0 {
start = Some(
arg.1
.clone()
.to_basic_value_enum(ctx, generator, ty_i32)?
.into_int_value(),
);
} else if i == 1 {
stop = Some(
arg.1
.clone()
.to_basic_value_enum(ctx, generator, ty_i32)?
.into_int_value(),
);
} else if i == 2 {
step = Some(
arg.1
.clone()
.to_basic_value_enum(ctx, generator, ty_i32)?
.into_int_value(),
);
}
}
let step = match step {
Some(step) => {
// assert step != 0, throw exception if not
let not_zero = ctx
.builder
.build_int_compare(
IntPredicate::NE,
step,
step.get_type().const_zero(),
"range_step_ne",
)
.unwrap();
ctx.make_assert(
generator,
not_zero,
"0:ValueError",
"range() step must not be zero",
[None, None, None],
ctx.current_loc,
);
step
}
None => int32.const_int(1, false),
};
let stop = stop.unwrap_or_else(|| {
let v = start.unwrap();
start = None;
v
});
let start = start.unwrap_or_else(|| int32.const_zero());
zelf.store_start(ctx, start);
zelf.store_end(ctx, stop);
zelf.store_step(ctx, step);
Ok(Some(zelf.as_base_value().into()))
},
)))),
loc: None,
},
_ => unreachable!(),
}
}
/// Build the class `Exception` and its associated methods. /// Build the class `Exception` and its associated methods.
fn build_exception_class_related(&self, prim: PrimDef) -> TopLevelDef { fn build_exception_class_related(&self, prim: PrimDef) -> TopLevelDef {
// NOTE: currently only contains the class `Exception` // NOTE: currently only contains the class `Exception`
@ -1170,131 +1325,6 @@ impl<'a> BuiltinBuilder<'a> {
} }
} }
/// Build the `range()` function.
fn build_range_function(&mut self) -> TopLevelDef {
let prim = PrimDef::FunRange;
let PrimitiveStore { int32, range, .. } = *self.primitives;
TopLevelDef::Function {
name: prim.name().into(),
simple_name: prim.simple_name().into(),
signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg { name: "start".into(), ty: int32, default_value: None },
FuncArg {
name: "stop".into(),
ty: int32,
// placeholder
default_value: Some(SymbolValue::I32(0)),
},
FuncArg {
name: "step".into(),
ty: int32,
default_value: Some(SymbolValue::I32(1)),
},
],
ret: range,
vars: VarMap::default(),
})),
var_id: Vec::default(),
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, _, _, args, generator| {
let mut start = None;
let mut stop = None;
let mut step = None;
let int32 = ctx.ctx.i32_type();
let ty_i32 = ctx.primitives.int32;
for (i, arg) in args.iter().enumerate() {
if arg.0 == Some("start".into()) {
start = Some(
arg.1
.clone()
.to_basic_value_enum(ctx, generator, ty_i32)?
.into_int_value(),
);
} else if arg.0 == Some("stop".into()) {
stop = Some(
arg.1
.clone()
.to_basic_value_enum(ctx, generator, ty_i32)?
.into_int_value(),
);
} else if arg.0 == Some("step".into()) {
step = Some(
arg.1
.clone()
.to_basic_value_enum(ctx, generator, ty_i32)?
.into_int_value(),
);
} else if i == 0 {
start = Some(
arg.1
.clone()
.to_basic_value_enum(ctx, generator, ty_i32)?
.into_int_value(),
);
} else if i == 1 {
stop = Some(
arg.1
.clone()
.to_basic_value_enum(ctx, generator, ty_i32)?
.into_int_value(),
);
} else if i == 2 {
step = Some(
arg.1
.clone()
.to_basic_value_enum(ctx, generator, ty_i32)?
.into_int_value(),
);
}
}
let step = match step {
Some(step) => {
// assert step != 0, throw exception if not
let not_zero = ctx
.builder
.build_int_compare(
IntPredicate::NE,
step,
step.get_type().const_zero(),
"range_step_ne",
)
.unwrap();
ctx.make_assert(
generator,
not_zero,
"0:ValueError",
"range() step must not be zero",
[None, None, None],
ctx.current_loc,
);
step
}
None => int32.const_int(1, false),
};
let stop = stop.unwrap_or_else(|| {
let v = start.unwrap();
start = None;
v
});
let start = start.unwrap_or_else(|| int32.const_zero());
let ptr = RangeType::new(ctx.ctx).new_value(generator, ctx, Some("range"));
ptr.store_start(ctx, start);
ptr.store_end(ctx, stop);
ptr.store_step(ctx, step);
Ok(Some(ptr.as_base_value().into()))
},
)))),
loc: None,
}
}
/// Build the `str()` function. /// Build the `str()` function.
fn build_str_function(&mut self) -> TopLevelDef { fn build_str_function(&mut self) -> TopLevelDef {
let prim = PrimDef::FunStr; let prim = PrimDef::FunStr;

View File

@ -49,7 +49,7 @@ pub enum PrimDef {
FunRound, FunRound,
FunRound64, FunRound64,
FunNpRound, FunNpRound,
FunRange, FunRangeInit,
FunStr, FunStr,
FunBool, FunBool,
FunFloor, FunFloor,
@ -203,7 +203,7 @@ impl PrimDef {
PrimDef::FunRound => fun("round", None), PrimDef::FunRound => fun("round", None),
PrimDef::FunRound64 => fun("round64", None), PrimDef::FunRound64 => fun("round64", None),
PrimDef::FunNpRound => fun("np_round", None), PrimDef::FunNpRound => fun("np_round", None),
PrimDef::FunRange => fun("range", None), PrimDef::FunRangeInit => fun("range.__init__", Some("__init__")),
PrimDef::FunStr => fun("str", None), PrimDef::FunStr => fun("str", None),
PrimDef::FunBool => fun("bool", None), PrimDef::FunBool => fun("bool", None),
PrimDef::FunFloor => fun("floor", None), PrimDef::FunFloor => fun("floor", None),