Cleanup PrimDef #479

Merged
sb10q merged 4 commits from clean-primdef into master 2024-07-27 21:57:47 +08:00
2 changed files with 103 additions and 78 deletions

View File

@ -346,8 +346,8 @@ impl<'a> BuiltinBuilder<'a> {
let (is_some_ty, unwrap_ty, option_tvar) = let (is_some_ty, unwrap_ty, option_tvar) =
if let TypeEnum::TObj { fields, params, .. } = unifier.get_ty(option).as_ref() { if let TypeEnum::TObj { fields, params, .. } = unifier.get_ty(option).as_ref() {
( (
*fields.get(&PrimDef::OptionIsSome.simple_name().into()).unwrap(), *fields.get(&PrimDef::FunOptionIsSome.simple_name().into()).unwrap(),
*fields.get(&PrimDef::OptionUnwrap.simple_name().into()).unwrap(), *fields.get(&PrimDef::FunOptionUnwrap.simple_name().into()).unwrap(),
iter_type_vars(params).next().unwrap(), iter_type_vars(params).next().unwrap(),
) )
} else { } else {
@ -362,9 +362,9 @@ impl<'a> BuiltinBuilder<'a> {
let ndarray_dtype_tvar = iter_type_vars(ndarray_params).next().unwrap(); let ndarray_dtype_tvar = iter_type_vars(ndarray_params).next().unwrap();
let ndarray_ndims_tvar = iter_type_vars(ndarray_params).nth(1).unwrap(); let ndarray_ndims_tvar = iter_type_vars(ndarray_params).nth(1).unwrap();
let ndarray_copy_ty = let ndarray_copy_ty =
*ndarray_fields.get(&PrimDef::NDArrayCopy.simple_name().into()).unwrap(); *ndarray_fields.get(&PrimDef::FunNDArrayCopy.simple_name().into()).unwrap();
let ndarray_fill_ty = let ndarray_fill_ty =
*ndarray_fields.get(&PrimDef::NDArrayFill.simple_name().into()).unwrap(); *ndarray_fields.get(&PrimDef::FunNDArrayFill.simple_name().into()).unwrap();
let num_ty = unifier.get_fresh_var_with_range( let num_ty = unifier.get_fresh_var_with_range(
&[int32, int64, float, boolean, uint32, uint64], &[int32, int64, float, boolean, uint32, uint64],
@ -464,14 +464,14 @@ impl<'a> BuiltinBuilder<'a> {
PrimDef::Exception => self.build_exception_class_related(prim), PrimDef::Exception => self.build_exception_class_related(prim),
PrimDef::Option PrimDef::Option
| PrimDef::OptionIsSome | PrimDef::FunOptionIsSome
| PrimDef::OptionIsNone | PrimDef::FunOptionIsNone
| PrimDef::OptionUnwrap | PrimDef::FunOptionUnwrap
| PrimDef::FunSome => self.build_option_class_related(prim), | PrimDef::FunSome => self.build_option_class_related(prim),
PrimDef::List => self.build_list_class_related(prim), PrimDef::List => self.build_list_class_related(prim),
PrimDef::NDArray | PrimDef::NDArrayCopy | PrimDef::NDArrayFill => { PrimDef::NDArray | PrimDef::FunNDArrayCopy | PrimDef::FunNDArrayFill => {
self.build_ndarray_class_related(prim) self.build_ndarray_class_related(prim)
} }
@ -794,9 +794,9 @@ impl<'a> BuiltinBuilder<'a> {
prim, prim,
&[ &[
PrimDef::Option, PrimDef::Option,
PrimDef::OptionIsSome, PrimDef::FunOptionIsSome,
PrimDef::OptionIsNone, PrimDef::FunOptionIsNone,
PrimDef::OptionUnwrap, PrimDef::FunOptionUnwrap,
PrimDef::FunSome, PrimDef::FunSome,
], ],
); );
@ -809,9 +809,9 @@ impl<'a> BuiltinBuilder<'a> {
fields: Vec::default(), fields: Vec::default(),
attributes: Vec::default(), attributes: Vec::default(),
methods: vec![ methods: vec![
Self::create_method(PrimDef::OptionIsSome, self.is_some_ty.0), Self::create_method(PrimDef::FunOptionIsSome, self.is_some_ty.0),
Self::create_method(PrimDef::OptionIsNone, self.is_some_ty.0), Self::create_method(PrimDef::FunOptionIsNone, self.is_some_ty.0),
Self::create_method(PrimDef::OptionUnwrap, self.unwrap_ty.0), Self::create_method(PrimDef::FunOptionUnwrap, self.unwrap_ty.0),
], ],
ancestors: vec![TypeAnnotation::CustomClass { ancestors: vec![TypeAnnotation::CustomClass {
id: prim.id(), id: prim.id(),
@ -822,7 +822,7 @@ impl<'a> BuiltinBuilder<'a> {
loc: None, loc: None,
}, },
PrimDef::OptionUnwrap => TopLevelDef::Function { PrimDef::FunOptionUnwrap => TopLevelDef::Function {
name: prim.name().into(), name: prim.name().into(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.unwrap_ty.0, signature: self.unwrap_ty.0,
@ -836,7 +836,7 @@ impl<'a> BuiltinBuilder<'a> {
loc: None, loc: None,
}, },
PrimDef::OptionIsNone | PrimDef::OptionIsSome => TopLevelDef::Function { PrimDef::FunOptionIsNone | PrimDef::FunOptionIsSome => TopLevelDef::Function {
name: prim.name().to_string(), name: prim.name().to_string(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.is_some_ty.0, signature: self.is_some_ty.0,
@ -857,10 +857,10 @@ impl<'a> BuiltinBuilder<'a> {
}; };
let returned_int = match prim { let returned_int = match prim {
PrimDef::OptionIsNone => { PrimDef::FunOptionIsNone => {
ctx.builder.build_is_null(ptr, prim.simple_name()) ctx.builder.build_is_null(ptr, prim.simple_name())
} }
PrimDef::OptionIsSome => { PrimDef::FunOptionIsSome => {
ctx.builder.build_is_not_null(ptr, prim.simple_name()) ctx.builder.build_is_not_null(ptr, prim.simple_name())
} }
_ => unreachable!(), _ => unreachable!(),
@ -933,7 +933,7 @@ impl<'a> BuiltinBuilder<'a> {
fn build_ndarray_class_related(&self, prim: PrimDef) -> TopLevelDef { fn build_ndarray_class_related(&self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed( debug_assert_prim_is_allowed(
prim, prim,
&[PrimDef::NDArray, PrimDef::NDArrayCopy, PrimDef::NDArrayFill], &[PrimDef::NDArray, PrimDef::FunNDArrayCopy, PrimDef::FunNDArrayFill],
); );
match prim { match prim {
@ -944,8 +944,8 @@ impl<'a> BuiltinBuilder<'a> {
fields: Vec::default(), fields: Vec::default(),
attributes: Vec::default(), attributes: Vec::default(),
methods: vec![ methods: vec![
Self::create_method(PrimDef::NDArrayCopy, self.ndarray_copy_ty.0), Self::create_method(PrimDef::FunNDArrayCopy, self.ndarray_copy_ty.0),
Self::create_method(PrimDef::NDArrayFill, self.ndarray_fill_ty.0), Self::create_method(PrimDef::FunNDArrayFill, self.ndarray_fill_ty.0),
], ],
ancestors: Vec::default(), ancestors: Vec::default(),
constructor: None, constructor: None,
@ -953,7 +953,7 @@ impl<'a> BuiltinBuilder<'a> {
loc: None, loc: None,
}, },
PrimDef::NDArrayCopy => TopLevelDef::Function { PrimDef::FunNDArrayCopy => TopLevelDef::Function {
name: prim.name().into(), name: prim.name().into(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.ndarray_copy_ty.0, signature: self.ndarray_copy_ty.0,
@ -970,7 +970,7 @@ impl<'a> BuiltinBuilder<'a> {
loc: None, loc: None,
}, },
PrimDef::NDArrayFill => TopLevelDef::Function { PrimDef::FunNDArrayFill => TopLevelDef::Function {
name: prim.name().into(), name: prim.name().into(),
simple_name: prim.simple_name().into(), simple_name: prim.simple_name().into(),
signature: self.ndarray_fill_ty.0, signature: self.ndarray_fill_ty.0,

View File

@ -27,17 +27,22 @@ pub enum PrimDef {
List, List,
NDArray, NDArray,
// Member Functions // Option methods
OptionIsSome, FunOptionIsSome,
OptionIsNone, FunOptionIsNone,
OptionUnwrap, FunOptionUnwrap,
NDArrayCopy,
NDArrayFill, // Option-related functions
FunInt32, FunSome,
FunInt64,
FunUInt32, // NDArray methods
FunUInt64, FunNDArrayCopy,
FunFloat, FunNDArrayFill,
// Range methods
FunRangeInit,
// NumPy factory functions
FunNpNDArray, FunNpNDArray,
FunNpEmpty, FunNpEmpty,
FunNpZeros, FunNpZeros,
@ -46,28 +51,17 @@ pub enum PrimDef {
FunNpArray, FunNpArray,
FunNpEye, FunNpEye,
FunNpIdentity, FunNpIdentity,
FunRound,
FunRound64, // Miscellaneous NumPy & SciPy functions
FunNpRound, FunNpRound,
FunRangeInit,
FunStr,
FunBool,
FunFloor,
FunFloor64,
FunNpFloor, FunNpFloor,
FunCeil,
FunCeil64,
FunNpCeil, FunNpCeil,
FunLen,
FunMin,
FunNpMin, FunNpMin,
FunNpMinimum, FunNpMinimum,
FunNpArgmin, FunNpArgmin,
FunMax,
FunNpMax, FunNpMax,
FunNpMaximum, FunNpMaximum,
FunNpArgmax, FunNpArgmax,
FunAbs,
FunNpIsNan, FunNpIsNan,
FunNpIsInf, FunNpIsInf,
FunNpSin, FunNpSin,
@ -106,8 +100,24 @@ pub enum PrimDef {
FunNpHypot, FunNpHypot,
FunNpNextAfter, FunNpNextAfter,
// Top-Level Functions // Miscellaneous Python & NAC3 functions
FunSome, FunInt32,
FunInt64,
FunUInt32,
FunUInt64,
FunFloat,
FunRound,
FunRound64,
FunStr,
FunBool,
FunFloor,
FunFloor64,
FunCeil,
FunCeil64,
FunLen,
FunMin,
FunMax,
FunAbs,
} }
/// Associated details of a [`PrimDef`] /// Associated details of a [`PrimDef`]
@ -173,6 +183,7 @@ impl PrimDef {
} }
match self { match self {
// Classes
PrimDef::Int32 => class("int32", |primitives| primitives.int32), PrimDef::Int32 => class("int32", |primitives| primitives.int32),
PrimDef::Int64 => class("int64", |primitives| primitives.int64), PrimDef::Int64 => class("int64", |primitives| primitives.int64),
PrimDef::Float => class("float", |primitives| primitives.float), PrimDef::Float => class("float", |primitives| primitives.float),
@ -184,18 +195,25 @@ impl PrimDef {
PrimDef::UInt32 => class("uint32", |primitives| primitives.uint32), PrimDef::UInt32 => class("uint32", |primitives| primitives.uint32),
PrimDef::UInt64 => class("uint64", |primitives| primitives.uint64), PrimDef::UInt64 => class("uint64", |primitives| primitives.uint64),
PrimDef::Option => class("Option", |primitives| primitives.option), PrimDef::Option => class("Option", |primitives| primitives.option),
PrimDef::OptionIsSome => fun("Option.is_some", Some("is_some")),
PrimDef::OptionIsNone => fun("Option.is_none", Some("is_none")),
PrimDef::OptionUnwrap => fun("Option.unwrap", Some("unwrap")),
PrimDef::List => class("list", |primitives| primitives.list), PrimDef::List => class("list", |primitives| primitives.list),
PrimDef::NDArray => class("ndarray", |primitives| primitives.ndarray), PrimDef::NDArray => class("ndarray", |primitives| primitives.ndarray),
PrimDef::NDArrayCopy => fun("ndarray.copy", Some("copy")),
PrimDef::NDArrayFill => fun("ndarray.fill", Some("fill")), // Option methods
PrimDef::FunInt32 => fun("int32", None), PrimDef::FunOptionIsSome => fun("Option.is_some", Some("is_some")),
PrimDef::FunInt64 => fun("int64", None), PrimDef::FunOptionIsNone => fun("Option.is_none", Some("is_none")),
PrimDef::FunUInt32 => fun("uint32", None), PrimDef::FunOptionUnwrap => fun("Option.unwrap", Some("unwrap")),
PrimDef::FunUInt64 => fun("uint64", None),
PrimDef::FunFloat => fun("float", None), // Option-related functions
PrimDef::FunSome => fun("Some", None),
// NDArray methods
PrimDef::FunNDArrayCopy => fun("ndarray.copy", Some("copy")),
PrimDef::FunNDArrayFill => fun("ndarray.fill", Some("fill")),
// Range methods
PrimDef::FunRangeInit => fun("range.__init__", Some("__init__")),
// NumPy factory functions
PrimDef::FunNpNDArray => fun("np_ndarray", None), PrimDef::FunNpNDArray => fun("np_ndarray", None),
PrimDef::FunNpEmpty => fun("np_empty", None), PrimDef::FunNpEmpty => fun("np_empty", None),
PrimDef::FunNpZeros => fun("np_zeros", None), PrimDef::FunNpZeros => fun("np_zeros", None),
@ -204,28 +222,17 @@ impl PrimDef {
PrimDef::FunNpArray => fun("np_array", None), PrimDef::FunNpArray => fun("np_array", None),
PrimDef::FunNpEye => fun("np_eye", None), PrimDef::FunNpEye => fun("np_eye", None),
PrimDef::FunNpIdentity => fun("np_identity", None), PrimDef::FunNpIdentity => fun("np_identity", None),
PrimDef::FunRound => fun("round", None),
PrimDef::FunRound64 => fun("round64", None), // Miscellaneous NumPy & SciPy functions
PrimDef::FunNpRound => fun("np_round", None), PrimDef::FunNpRound => fun("np_round", None),
PrimDef::FunRangeInit => fun("range.__init__", Some("__init__")),
PrimDef::FunStr => fun("str", None),
PrimDef::FunBool => fun("bool", None),
PrimDef::FunFloor => fun("floor", None),
PrimDef::FunFloor64 => fun("floor64", None),
PrimDef::FunNpFloor => fun("np_floor", None), PrimDef::FunNpFloor => fun("np_floor", None),
PrimDef::FunCeil => fun("ceil", None),
PrimDef::FunCeil64 => fun("ceil64", None),
PrimDef::FunNpCeil => fun("np_ceil", None), PrimDef::FunNpCeil => fun("np_ceil", None),
PrimDef::FunLen => fun("len", None),
PrimDef::FunMin => fun("min", None),
PrimDef::FunNpMin => fun("np_min", None), PrimDef::FunNpMin => fun("np_min", None),
PrimDef::FunNpMinimum => fun("np_minimum", None), PrimDef::FunNpMinimum => fun("np_minimum", None),
PrimDef::FunNpArgmin => fun("np_argmin", None), PrimDef::FunNpArgmin => fun("np_argmin", None),
PrimDef::FunMax => fun("max", None),
PrimDef::FunNpMax => fun("np_max", None), PrimDef::FunNpMax => fun("np_max", None),
PrimDef::FunNpMaximum => fun("np_maximum", None), PrimDef::FunNpMaximum => fun("np_maximum", None),
PrimDef::FunNpArgmax => fun("np_argmax", None), PrimDef::FunNpArgmax => fun("np_argmax", None),
PrimDef::FunAbs => fun("abs", None),
PrimDef::FunNpIsNan => fun("np_isnan", None), PrimDef::FunNpIsNan => fun("np_isnan", None),
PrimDef::FunNpIsInf => fun("np_isinf", None), PrimDef::FunNpIsInf => fun("np_isinf", None),
PrimDef::FunNpSin => fun("np_sin", None), PrimDef::FunNpSin => fun("np_sin", None),
@ -263,7 +270,25 @@ impl PrimDef {
PrimDef::FunNpLdExp => fun("np_ldexp", None), PrimDef::FunNpLdExp => fun("np_ldexp", None),
PrimDef::FunNpHypot => fun("np_hypot", None), PrimDef::FunNpHypot => fun("np_hypot", None),
PrimDef::FunNpNextAfter => fun("np_nextafter", None), PrimDef::FunNpNextAfter => fun("np_nextafter", None),
PrimDef::FunSome => fun("Some", None),
// Miscellaneous Python & NAC3 functions
PrimDef::FunInt32 => fun("int32", None),
PrimDef::FunInt64 => fun("int64", None),
PrimDef::FunUInt32 => fun("uint32", None),
PrimDef::FunUInt64 => fun("uint64", None),
PrimDef::FunFloat => fun("float", None),
PrimDef::FunRound => fun("round", None),
PrimDef::FunRound64 => fun("round64", None),
PrimDef::FunStr => fun("str", None),
PrimDef::FunBool => fun("bool", None),
PrimDef::FunFloor => fun("floor", None),
PrimDef::FunFloor64 => fun("floor64", None),
PrimDef::FunCeil => fun("ceil", None),
PrimDef::FunCeil64 => fun("ceil64", None),
PrimDef::FunLen => fun("len", None),
PrimDef::FunMin => fun("min", None),
PrimDef::FunMax => fun("max", None),
PrimDef::FunAbs => fun("abs", None),
} }
} }
} }
@ -414,9 +439,9 @@ impl TopLevelComposer {
let option = unifier.add_ty(TypeEnum::TObj { let option = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::Option.id(), obj_id: PrimDef::Option.id(),
fields: vec![ fields: vec![
(PrimDef::OptionIsSome.simple_name().into(), (is_some_type_fun_ty, true)), (PrimDef::FunOptionIsSome.simple_name().into(), (is_some_type_fun_ty, true)),
(PrimDef::OptionIsNone.simple_name().into(), (is_some_type_fun_ty, true)), (PrimDef::FunOptionIsNone.simple_name().into(), (is_some_type_fun_ty, true)),
(PrimDef::OptionUnwrap.simple_name().into(), (unwrap_fun_ty, true)), (PrimDef::FunOptionUnwrap.simple_name().into(), (unwrap_fun_ty, true)),
] ]
.into_iter() .into_iter()
.collect::<HashMap<_, _>>(), .collect::<HashMap<_, _>>(),
@ -457,8 +482,8 @@ impl TopLevelComposer {
let ndarray = unifier.add_ty(TypeEnum::TObj { let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PrimDef::NDArray.id(), obj_id: PrimDef::NDArray.id(),
fields: Mapping::from([ fields: Mapping::from([
(PrimDef::NDArrayCopy.simple_name().into(), (ndarray_copy_fun_ty, true)), (PrimDef::FunNDArrayCopy.simple_name().into(), (ndarray_copy_fun_ty, true)),
(PrimDef::NDArrayFill.simple_name().into(), (ndarray_fill_fun_ty, true)), (PrimDef::FunNDArrayFill.simple_name().into(), (ndarray_fill_fun_ty, true)),
]), ]),
params: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]), params: into_var_map([ndarray_dtype_tvar, ndarray_ndims_tvar]),
}); });