From c4420e6ab977eea9be4ed7cec014e935881438d8 Mon Sep 17 00:00:00 2001 From: lyken Date: Wed, 12 Jun 2024 15:09:20 +0800 Subject: [PATCH] core: refactor `get_builtins()` --- nac3core/src/toplevel/builtins.rs | 3038 ++++++++++++----------------- nac3core/src/toplevel/helper.rs | 150 ++ 2 files changed, 1416 insertions(+), 1772 deletions(-) diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index fdc157e3..a0dbb5de 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1,5 +1,6 @@ use std::iter::once; +use helper::{debug_assert_prim_is_allowed, PrimDefDetails}; use indexmap::IndexMap; use inkwell::{ attributes::{Attribute, AttributeLoc}, @@ -8,6 +9,7 @@ use inkwell::{ IntPredicate, }; use itertools::Either; +use strum::IntoEnumIterator; use crate::{ codegen::{ @@ -113,8 +115,8 @@ fn create_fn_by_codegen( ret_ty: Type, param_ty: &[(Type, &'static str)], codegen_callback: Box, -) -> Arc> { - Arc::new(RwLock::new(TopLevelDef::Function { +) -> TopLevelDef { + TopLevelDef::Function { name: name.into(), simple_name: name.into(), signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { @@ -131,7 +133,7 @@ fn create_fn_by_codegen( resolver: None, codegen_callback: Some(Arc::new(GenCall::new(codegen_callback))), loc: None, - })) + } } /// Creates a NumPy [`TopLevelDef`] function using an LLVM intrinsic. @@ -148,7 +150,7 @@ fn create_fn_by_intrinsic( ret_ty: Type, params: &[(Type, &'static str)], intrinsic_fn: &'static str, -) -> Arc> { +) -> TopLevelDef { let param_tys = params.iter().map(|p| p.0).collect_vec(); create_fn_by_codegen( @@ -213,7 +215,7 @@ fn create_fn_by_extern( params: &[(Type, &'static str)], extern_fn: &'static str, attrs: &'static [&str], -) -> Arc> { +) -> TopLevelDef { let param_tys = params.iter().map(|p| p.0).collect_vec(); create_fn_by_codegen( @@ -273,658 +275,867 @@ fn create_fn_by_extern( } pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> BuiltinInfo { - let PrimitiveStore { - int32, - int64, - uint32, - uint64, - float, - bool: boolean, - range, - str: string, - ndarray, - .. - } = *primitives; + let top_level_def_list = BuiltinBuilder::new(unifier, primitives) + .build_all_builtins() + .into_iter() + .map(|tld| Arc::new(RwLock::new(tld))); - let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), None); - let ndarray_float_2d = { - let value = match primitives.size_t { - 64 => SymbolValue::U64(2u64), - 32 => SymbolValue::U32(2u32), - _ => unreachable!(), - }; - let ndims = unifier.add_ty(TypeEnum::TLiteral { values: vec![value], loc: None }); + let ast_list: Vec>> = (0..top_level_def_list.len()).map(|_| None).collect(); - make_ndarray_ty(unifier, primitives, Some(float), Some(ndims)) - }; - let list_int32 = unifier.add_ty(TypeEnum::TList { ty: int32 }); - let num_ty = unifier.get_fresh_var_with_range( - &[int32, int64, float, boolean, uint32, uint64], - Some("N".into()), - None, - ); - let num_var_map: VarMap = vec![(num_ty.1, num_ty.0)].into_iter().collect(); + izip!(top_level_def_list, ast_list).collect_vec() +} - let new_type_or_ndarray_ty = - |unifier: &mut Unifier, primitives: &PrimitiveStore, scalar_ty: Type| { - let ndarray = make_ndarray_ty(unifier, primitives, Some(scalar_ty), None); +/// A helper enum used by [`BuiltinBuilder`] +#[derive(Clone, Copy)] +enum SizeVariant { + Bits32, + Bits64, +} - unifier.get_fresh_var_with_range(&[scalar_ty, ndarray], Some("T".into()), None) - }; +impl SizeVariant { + fn of_int(self, primitives: &PrimitiveStore) -> Type { + match self { + SizeVariant::Bits32 => primitives.int32, + SizeVariant::Bits64 => primitives.int64, + } + } +} - let ndarray_num_ty = make_ndarray_ty(unifier, primitives, Some(num_ty.0), None); - let float_or_ndarray_ty = - unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None); - let float_or_ndarray_var_map: VarMap = - vec![(float_or_ndarray_ty.1, float_or_ndarray_ty.0)].into_iter().collect(); +struct BuiltinBuilder<'a> { + unifier: &'a mut Unifier, + primitives: &'a PrimitiveStore, - let num_or_ndarray_ty = - unifier.get_fresh_var_with_range(&[num_ty.0, ndarray_num_ty], Some("T".into()), None); - let num_or_ndarray_var_map: VarMap = - vec![(num_ty.1, num_ty.0), (num_or_ndarray_ty.1, num_or_ndarray_ty.0)] - .into_iter() - .collect(); + is_some_ty: (Type, bool), + unwrap_ty: (Type, bool), + option_tvar: (Type, u32), - let exception_fields = vec![ - ("__name__".into(), int32, true), - ("__file__".into(), string, true), - ("__line__".into(), int32, true), - ("__col__".into(), int32, true), - ("__func__".into(), string, true), - ("__message__".into(), string, true), - ("__param0__".into(), int64, true), - ("__param1__".into(), int64, true), - ("__param2__".into(), int64, true), - ]; + ndarray_dtype_tvar: (Type, u32), + ndarray_ndims_tvar: (Type, u32), + ndarray_copy_ty: (Type, bool), + ndarray_fill_ty: (Type, bool), - // for Option, is_some and is_none share the same type: () -> bool, - // and they are methods under the same class `Option` - let (is_some_ty, unwrap_ty, (option_ty_var, option_ty_var_id)) = - if let TypeEnum::TObj { fields, params, .. } = unifier.get_ty(primitives.option).as_ref() { - ( - *fields.get(&"is_some".into()).unwrap(), - *fields.get(&"unwrap".into()).unwrap(), - (*params.iter().next().unwrap().1, *params.iter().next().unwrap().0), - ) - } else { + list_int32: Type, + + num_ty: (Type, u32), + num_var_map: VarMap, + + ndarray_float: Type, + ndarray_float_2d: Type, + ndarray_num_ty: Type, + + float_or_ndarray_ty: (Type, u32), + float_or_ndarray_var_map: VarMap, + + num_or_ndarray_ty: (Type, u32), + num_or_ndarray_var_map: VarMap, +} + +impl<'a> BuiltinBuilder<'a> { + fn new(unifier: &'a mut Unifier, primitives: &'a PrimitiveStore) -> BuiltinBuilder<'a> { + let PrimitiveStore { + int32, + int64, + uint32, + uint64, + float, + bool: boolean, + ndarray, + option, + .. + } = *primitives; + + // Option-related + let (is_some_ty, unwrap_ty, option_tvar) = + if let TypeEnum::TObj { fields, params, .. } = unifier.get_ty(option).as_ref() { + ( + *fields.get(&PrimDef::OptionIsSome.simple_name().into()).unwrap(), + *fields.get(&PrimDef::OptionUnwrap.simple_name().into()).unwrap(), + (*params.iter().next().unwrap().1, *params.iter().next().unwrap().0), + ) + } else { + unreachable!() + }; + + let TypeEnum::TObj { fields: ndarray_fields, params: ndarray_params, .. } = + &*unifier.get_ty(ndarray) + else { unreachable!() }; + let ndarray_dtype_tvar = + ndarray_params.iter().next().map(|(var_id, ty)| (*ty, *var_id)).unwrap(); + let ndarray_ndims_tvar = + ndarray_params.iter().nth(1).map(|(var_id, ty)| (*ty, *var_id)).unwrap(); + let ndarray_copy_ty = + *ndarray_fields.get(&PrimDef::NDArrayCopy.simple_name().into()).unwrap(); + let ndarray_fill_ty = + *ndarray_fields.get(&PrimDef::NDArrayFill.simple_name().into()).unwrap(); - let TypeEnum::TObj { fields: ndarray_fields, params: ndarray_params, .. } = - &*unifier.get_ty(primitives.ndarray) - else { - unreachable!() - }; + let num_ty = unifier.get_fresh_var_with_range( + &[int32, int64, float, boolean, uint32, uint64], + Some("N".into()), + None, + ); + let num_var_map: VarMap = vec![(num_ty.1, num_ty.0)].into_iter().collect(); - let (ndarray_dtype_ty, ndarray_dtype_var_id) = - ndarray_params.iter().next().map(|(var_id, ty)| (*ty, *var_id)).unwrap(); - let (ndarray_ndims_ty, ndarray_ndims_var_id) = - ndarray_params.iter().nth(1).map(|(var_id, ty)| (*ty, *var_id)).unwrap(); - let ndarray_copy_ty = *ndarray_fields.get(&"copy".into()).unwrap(); - let ndarray_fill_ty = *ndarray_fields.get(&"fill".into()).unwrap(); + let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), None); + let ndarray_float_2d = { + let value = match primitives.size_t { + 64 => SymbolValue::U64(2u64), + 32 => SymbolValue::U32(2u32), + _ => unreachable!(), + }; + let ndims = unifier.add_ty(TypeEnum::TLiteral { values: vec![value], loc: None }); - let top_level_def_list = vec![ - Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( - PrimDef::Int32.id(), - None, - "int32".into(), - None, - None, - ))), - Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( - PrimDef::Int64.id(), - None, - "int64".into(), - None, - None, - ))), - Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( - PrimDef::Float.id(), - None, - "float".into(), - None, - None, - ))), - Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( - PrimDef::Bool.id(), - None, - "bool".into(), - None, - None, - ))), - Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( - PrimDef::None.id(), - None, - "none".into(), - None, - None, - ))), - Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( - PrimDef::Range.id(), - None, - "range".into(), - None, - None, - ))), - Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( - PrimDef::Str.id(), - None, - "str".into(), - None, - None, - ))), - Arc::new(RwLock::new(TopLevelDef::Class { - name: "Exception".into(), - object_id: PrimDef::Exception.id(), - type_vars: Vec::default(), - fields: exception_fields, - methods: Vec::default(), - ancestors: vec![], - constructor: None, - resolver: None, - loc: None, - })), - Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( - PrimDef::UInt32.id(), - None, - "uint32".into(), - None, - None, - ))), - Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( - PrimDef::UInt64.id(), - None, - "uint64".into(), - None, - None, - ))), - Arc::new(RwLock::new({ - TopLevelDef::Class { - name: "Option".into(), - object_id: PrimDef::Option.id(), - type_vars: vec![option_ty_var], + make_ndarray_ty(unifier, primitives, Some(float), Some(ndims)) + }; + + let ndarray_num_ty = make_ndarray_ty(unifier, primitives, Some(num_ty.0), None); + let float_or_ndarray_ty = + unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None); + let float_or_ndarray_var_map: VarMap = + vec![(float_or_ndarray_ty.1, float_or_ndarray_ty.0)].into_iter().collect(); + + let num_or_ndarray_ty = + unifier.get_fresh_var_with_range(&[num_ty.0, ndarray_num_ty], Some("T".into()), None); + let num_or_ndarray_var_map: VarMap = + vec![(num_ty.1, num_ty.0), (num_or_ndarray_ty.1, num_or_ndarray_ty.0)] + .into_iter() + .collect(); + + let list_int32 = unifier.add_ty(TypeEnum::TList { ty: int32 }); + + BuiltinBuilder { + unifier, + primitives, + + is_some_ty, + unwrap_ty, + option_tvar, + + ndarray_dtype_tvar, + ndarray_ndims_tvar, + ndarray_copy_ty, + ndarray_fill_ty, + + list_int32, + + num_ty, + num_var_map, + + ndarray_float, + ndarray_float_2d, + ndarray_num_ty, + + float_or_ndarray_ty, + float_or_ndarray_var_map, + + num_or_ndarray_ty, + num_or_ndarray_var_map, + } + } + + /// Construct every function from every [`PrimDef`], in the order of [`PrimDef`]'s definition. + fn build_all_builtins(&mut self) -> Vec { + PrimDef::iter().map(|prim| self.build_builtin_of_prim(prim)).collect_vec() + } + + /// Build the [`TopLevelDef`] associated of a [`PrimDef`]. + fn build_builtin_of_prim(&mut self, prim: PrimDef) -> TopLevelDef { + let tld = match prim { + PrimDef::Int32 + | PrimDef::Int64 + | PrimDef::UInt32 + | PrimDef::UInt64 + | PrimDef::Float + | PrimDef::Bool + | PrimDef::Str + | PrimDef::Range + | PrimDef::None => Self::build_simple_primitive_class(prim), + + PrimDef::Exception => self.build_exception_class_related(prim), + + PrimDef::Option + | PrimDef::OptionIsSome + | PrimDef::OptionIsNone + | PrimDef::OptionUnwrap + | PrimDef::FunSome => self.build_option_class_related(prim), + + PrimDef::NDArray | PrimDef::NDArrayCopy | PrimDef::NDArrayFill => { + self.build_ndarray_class_related(prim) + } + + PrimDef::FunInt32 + | PrimDef::FunInt64 + | PrimDef::FunUInt32 + | PrimDef::FunUInt64 + | PrimDef::FunFloat + | PrimDef::FunBool => self.build_cast_function(prim), + + PrimDef::FunNpNDArray + | PrimDef::FunNpEmpty + | PrimDef::FunNpZeros + | PrimDef::FunNpOnes => self.build_ndarray_from_shape_factory_function(prim), + + PrimDef::FunNpArray + | PrimDef::FunNpFull + | PrimDef::FunNpEye + | PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim), + + PrimDef::FunRange => self.build_range_function(), + PrimDef::FunStr => self.build_str_function(), + + PrimDef::FunFloor | PrimDef::FunFloor64 | PrimDef::FunCeil | PrimDef::FunCeil64 => { + self.build_ceil_floor_function(prim) + } + + PrimDef::FunAbs => self.build_abs_function(), + + PrimDef::FunRound | PrimDef::FunRound64 => self.build_round_function(prim), + + PrimDef::FunNpFloor | PrimDef::FunNpCeil => self.build_np_ceil_floor_function(prim), + + PrimDef::FunNpRound => self.build_np_round_function(), + + PrimDef::FunLen => self.build_len_function(), + + PrimDef::FunMin | PrimDef::FunMax => self.build_min_max_function(prim), + + PrimDef::FunNpMin | PrimDef::FunNpMax => self.build_np_min_max_function(prim), + + PrimDef::FunNpMinimum | PrimDef::FunNpMaximum => { + self.build_np_minimum_maximum_function(prim) + } + + PrimDef::FunNpIsNan | PrimDef::FunNpIsInf => self.build_np_float_to_bool_function(prim), + + PrimDef::FunNpSin + | PrimDef::FunNpCos + | PrimDef::FunNpTan + | PrimDef::FunNpArcsin + | PrimDef::FunNpArccos + | PrimDef::FunNpArctan + | PrimDef::FunNpSinh + | PrimDef::FunNpCosh + | PrimDef::FunNpTanh + | PrimDef::FunNpArcsinh + | PrimDef::FunNpArccosh + | PrimDef::FunNpArctanh + | PrimDef::FunNpExp + | PrimDef::FunNpExp2 + | PrimDef::FunNpExpm1 + | PrimDef::FunNpLog + | PrimDef::FunNpLog2 + | PrimDef::FunNpLog10 + | PrimDef::FunNpSqrt + | PrimDef::FunNpCbrt + | PrimDef::FunNpFabs + | PrimDef::FunNpRint + | PrimDef::FunSpSpecErf + | PrimDef::FunSpSpecErfc + | PrimDef::FunSpSpecGamma + | PrimDef::FunSpSpecGammaln + | PrimDef::FunSpSpecJ0 + | PrimDef::FunSpSpecJ1 => self.build_np_sp_float_or_ndarray_1ary_function(prim), + + PrimDef::FunNpArctan2 + | PrimDef::FunNpCopysign + | PrimDef::FunNpFmax + | PrimDef::FunNpFmin + | PrimDef::FunNpLdExp + | PrimDef::FunNpHypot + | PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim), + }; + + if cfg!(debug_assertions) { + // Sanity checks on the constructed [`TopLevelDef`] + + match (&tld, prim.details()) { + ( + TopLevelDef::Class { name, object_id, .. }, + PrimDefDetails::PrimClass { name: exp_name }, + ) => { + let exp_object_id = prim.id(); + assert_eq!(name, &exp_name.into()); + assert_eq!(object_id, &exp_object_id); + } + ( + TopLevelDef::Function { name, simple_name, .. }, + PrimDefDetails::PrimFunction { name: exp_name, simple_name: exp_simple_name }, + ) => { + assert_eq!(name, exp_name); + assert_eq!(simple_name, &exp_simple_name.into()); + } + _ => { + panic!("Class/function variant of the constructed TopLevelDef of PrimDef {prim:?} is different than what is defined by {prim:?}") + } + } + } + + tld + } + + /// Build "simple" primitive classes. + fn build_simple_primitive_class(prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed( + prim, + &[ + PrimDef::Int32, + PrimDef::Int64, + PrimDef::UInt32, + PrimDef::UInt64, + PrimDef::Float, + PrimDef::Bool, + PrimDef::Str, + PrimDef::Range, + PrimDef::None, + ], + ); + + TopLevelComposer::make_top_level_class_def(prim.id(), None, prim.name().into(), None, None) + } + + /// Build the class `Exception` and its associated methods. + fn build_exception_class_related(&self, prim: PrimDef) -> TopLevelDef { + // NOTE: currently only contains the class `Exception` + debug_assert_prim_is_allowed(prim, &[PrimDef::Exception]); + + let PrimitiveStore { int32, int64, str, .. } = *self.primitives; + + match prim { + PrimDef::Exception => { + let exception_fields: Vec<(StrRef, Type, bool)> = vec![ + ("__name__".into(), int32, true), + ("__file__".into(), str, true), + ("__line__".into(), int32, true), + ("__col__".into(), int32, true), + ("__func__".into(), str, true), + ("__message__".into(), str, true), + ("__param0__".into(), int64, true), + ("__param1__".into(), int64, true), + ("__param2__".into(), int64, true), + ]; + + TopLevelDef::Class { + name: prim.name().into(), + object_id: prim.id(), + type_vars: Vec::default(), + fields: exception_fields, + methods: Vec::default(), + ancestors: vec![], + constructor: None, + resolver: None, + loc: None, + } + } + _ => unreachable!(), + } + } + + /// Build the class `Option`, its associated methods and the function `Some()`. + fn build_option_class_related(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed( + prim, + &[ + PrimDef::Option, + PrimDef::OptionIsSome, + PrimDef::OptionIsNone, + PrimDef::OptionUnwrap, + PrimDef::FunSome, + ], + ); + + match prim { + PrimDef::Option => TopLevelDef::Class { + name: prim.name().into(), + object_id: prim.id(), + type_vars: vec![self.option_tvar.0], fields: vec![], methods: vec![ - ("is_some".into(), is_some_ty.0, PrimDef::OptionIsSome.id()), - ("is_none".into(), is_some_ty.0, PrimDef::OptionIsNone.id()), - ("unwrap".into(), unwrap_ty.0, PrimDef::OptionUnwrap.id()), + Self::create_method(PrimDef::OptionIsSome, self.is_some_ty.0), + Self::create_method(PrimDef::OptionIsNone, self.is_some_ty.0), + Self::create_method(PrimDef::OptionUnwrap, self.unwrap_ty.0), ], ancestors: vec![TypeAnnotation::CustomClass { - id: PrimDef::Option.id(), + id: prim.id(), params: Vec::default(), }], constructor: None, resolver: None, loc: None, - } - })), - Arc::new(RwLock::new(TopLevelDef::Function { - name: "Option.is_some".into(), - simple_name: "is_some".into(), - signature: is_some_ty.0, - var_id: vec![option_ty_var_id], - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, obj, _, _, generator| { - let expect_ty = obj.clone().unwrap().0; - let obj_val = - obj.unwrap().1.clone().to_basic_value_enum(ctx, generator, expect_ty)?; - let BasicValueEnum::PointerValue(ptr) = obj_val else { - unreachable!("option must be ptr") - }; + }, - Ok(Some(ctx.builder.build_is_not_null(ptr, "is_some").map(Into::into).unwrap())) - }, - )))), - loc: None, - })), - Arc::new(RwLock::new(TopLevelDef::Function { - name: "Option.is_none".into(), - simple_name: "is_none".into(), - signature: is_some_ty.0, - var_id: vec![option_ty_var_id], - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, obj, _, _, generator| { - let expect_ty = obj.clone().unwrap().0; - let obj_val = - obj.unwrap().1.clone().to_basic_value_enum(ctx, generator, expect_ty)?; - let BasicValueEnum::PointerValue(ptr) = obj_val else { - unreachable!("option must be ptr") - }; + PrimDef::OptionUnwrap => TopLevelDef::Function { + name: prim.name().into(), + simple_name: prim.simple_name().into(), + signature: self.unwrap_ty.0, + var_id: vec![self.option_tvar.1], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::create_dummy(String::from( + "handled in gen_expr", + )))), + loc: None, + }, - Ok(Some(ctx.builder.build_is_null(ptr, "is_none").map(Into::into).unwrap())) - }, - )))), - loc: None, - })), - Arc::new(RwLock::new(TopLevelDef::Function { - name: "Option.unwrap".into(), - simple_name: "unwrap".into(), - signature: unwrap_ty.0, - var_id: vec![option_ty_var_id], - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::create_dummy(String::from( - "handled in gen_expr", - )))), - loc: None, - })), - Arc::new(RwLock::new(TopLevelDef::Class { - name: "ndarray".into(), - object_id: PrimDef::NDArray.id(), - type_vars: vec![ndarray_dtype_ty, ndarray_ndims_ty], - fields: Vec::default(), - methods: vec![ - ("copy".into(), ndarray_copy_ty.0, PrimDef::NDArrayCopy.id()), - ("fill".into(), ndarray_fill_ty.0, PrimDef::NDArrayFill.id()), - ], - ancestors: Vec::default(), - constructor: None, - resolver: None, - loc: None, - })), - Arc::new(RwLock::new(TopLevelDef::Function { - name: "ndarray.copy".into(), - simple_name: "copy".into(), - signature: ndarray_copy_ty.0, - var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, obj, fun, args, generator| { - gen_ndarray_copy(ctx, &obj, fun, &args, generator) - .map(|val| Some(val.as_basic_value_enum())) - }, - )))), - loc: None, - })), - Arc::new(RwLock::new(TopLevelDef::Function { - name: "ndarray.fill".into(), - simple_name: "fill".into(), - signature: ndarray_fill_ty.0, - var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, obj, fun, args, generator| { - gen_ndarray_fill(ctx, &obj, fun, &args, generator)?; - Ok(None) - }, - )))), - loc: None, - })), - Arc::new(RwLock::new(TopLevelDef::Function { - name: "int32".into(), - simple_name: "int32".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { - name: "n".into(), - ty: num_or_ndarray_ty.0, - default_value: None, - }], - ret: num_or_ndarray_ty.0, - vars: num_or_ndarray_var_map.clone(), - })), - var_id: Vec::default(), - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, fun, args, generator| { - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; + PrimDef::OptionIsNone | PrimDef::OptionIsSome => TopLevelDef::Function { + name: prim.name().to_string(), + simple_name: prim.simple_name().into(), + signature: self.is_some_ty.0, + var_id: vec![self.option_tvar.1], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + move |ctx, obj, _, _, generator| { + let expect_ty = obj.clone().unwrap().0; + let obj_val = obj + .unwrap() + .1 + .clone() + .to_basic_value_enum(ctx, generator, expect_ty)?; + let BasicValueEnum::PointerValue(ptr) = obj_val else { + unreachable!("option must be ptr") + }; - Ok(Some(builtin_fns::call_int32(generator, ctx, (arg_ty, arg))?)) - }, - )))), - loc: None, - })), - Arc::new(RwLock::new(TopLevelDef::Function { - name: "int64".into(), - simple_name: "int64".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { - name: "n".into(), - ty: num_or_ndarray_ty.0, - default_value: None, - }], - ret: num_or_ndarray_ty.0, - vars: num_or_ndarray_var_map.clone(), - })), - var_id: Vec::default(), - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, fun, args, generator| { - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; + let returned_int = match prim { + PrimDef::OptionIsNone => { + ctx.builder.build_is_null(ptr, prim.simple_name()) + } + PrimDef::OptionIsSome => { + ctx.builder.build_is_not_null(ptr, prim.simple_name()) + } + _ => unreachable!(), + }; + Ok(Some(returned_int.map(Into::into).unwrap())) + }, + )))), + loc: None, + }, - Ok(Some(builtin_fns::call_int64(generator, ctx, (arg_ty, arg))?)) - }, - )))), - loc: None, - })), - Arc::new(RwLock::new(TopLevelDef::Function { - name: "uint32".into(), - simple_name: "uint32".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { - name: "n".into(), - ty: num_or_ndarray_ty.0, - default_value: None, - }], - ret: num_or_ndarray_ty.0, - vars: num_or_ndarray_var_map.clone(), - })), - var_id: Vec::default(), - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, fun, args, generator| { - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - - Ok(Some(builtin_fns::call_uint32(generator, ctx, (arg_ty, arg))?)) - }, - )))), - loc: None, - })), - Arc::new(RwLock::new(TopLevelDef::Function { - name: "uint64".into(), - simple_name: "uint64".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { - name: "n".into(), - ty: num_or_ndarray_ty.0, - default_value: None, - }], - ret: num_or_ndarray_ty.0, - vars: num_or_ndarray_var_map.clone(), - })), - var_id: Vec::default(), - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, fun, args, generator| { - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - - Ok(Some(builtin_fns::call_uint64(generator, ctx, (arg_ty, arg))?)) - }, - )))), - loc: None, - })), - Arc::new(RwLock::new(TopLevelDef::Function { - name: "float".into(), - simple_name: "float".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { - name: "n".into(), - ty: num_or_ndarray_ty.0, - default_value: None, - }], - ret: num_or_ndarray_ty.0, - vars: num_or_ndarray_var_map.clone(), - })), - var_id: Vec::default(), - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, fun, args, generator| { - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - - Ok(Some(builtin_fns::call_float(generator, ctx, (arg_ty, arg))?)) - }, - )))), - loc: None, - })), - create_fn_by_codegen( - unifier, - &VarMap::new(), - "np_ndarray", - ndarray_float, - // We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a - // type variable - &[(list_int32, "shape")], - Box::new(|ctx, obj, fun, args, generator| { - gen_ndarray_empty(ctx, &obj, fun, &args, generator) - .map(|val| Some(val.as_basic_value_enum())) - }), - ), - create_fn_by_codegen( - unifier, - &VarMap::new(), - "np_empty", - ndarray_float, - // We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a - // type variable - &[(list_int32, "shape")], - Box::new(|ctx, obj, fun, args, generator| { - gen_ndarray_empty(ctx, &obj, fun, &args, generator) - .map(|val| Some(val.as_basic_value_enum())) - }), - ), - create_fn_by_codegen( - unifier, - &VarMap::new(), - "np_zeros", - ndarray_float, - // We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a - // type variable - &[(list_int32, "shape")], - Box::new(|ctx, obj, fun, args, generator| { - gen_ndarray_zeros(ctx, &obj, fun, &args, generator) - .map(|val| Some(val.as_basic_value_enum())) - }), - ), - create_fn_by_codegen( - unifier, - &VarMap::new(), - "np_ones", - ndarray_float, - // We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a - // type variable - &[(list_int32, "shape")], - Box::new(|ctx, obj, fun, args, generator| { - gen_ndarray_ones(ctx, &obj, fun, &args, generator) - .map(|val| Some(val.as_basic_value_enum())) - }), - ), - { - let tv = unifier.get_fresh_var(Some("T".into()), None); - - create_fn_by_codegen( - unifier, - &[(tv.1, tv.0)].into_iter().collect(), - "np_full", - ndarray, - // We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a - // type variable - &[(list_int32, "shape"), (tv.0, "fill_value")], - Box::new(|ctx, obj, fun, args, generator| { - gen_ndarray_full(ctx, &obj, fun, &args, generator) - .map(|val| Some(val.as_basic_value_enum())) - }), - ) - }, - { - let tv = unifier.get_fresh_var(Some("T".into()), None); - - Arc::new(RwLock::new(TopLevelDef::Function { - name: "np_array".into(), - simple_name: "np_array".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![ - FuncArg { name: "object".into(), ty: tv.0, default_value: None }, - FuncArg { - name: "copy".into(), - ty: boolean, - default_value: Some(SymbolValue::Bool(true)), - }, - FuncArg { - name: "ndmin".into(), - ty: int32, - default_value: Some(SymbolValue::U32(0)), - }, - ], - ret: ndarray, - vars: VarMap::from([(tv.1, tv.0)]), + PrimDef::FunSome => TopLevelDef::Function { + name: prim.name().into(), + simple_name: prim.simple_name().into(), + signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![FuncArg { + name: "n".into(), + ty: self.option_tvar.0, + default_value: None, + }], + ret: self.primitives.option, + vars: VarMap::from([(self.option_tvar.1, self.option_tvar.0)]), })), - var_id: vec![tv.1], + var_id: vec![self.option_tvar.1], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, _, fun, args, generator| { + let arg_ty = fun.0.args[0].ty; + let arg_val = + args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; + let alloca = generator + .gen_var_alloc(ctx, arg_val.get_type(), Some("alloca_some")) + .unwrap(); + ctx.builder.build_store(alloca, arg_val).unwrap(); + Ok(Some(alloca.into())) + }, + )))), + loc: None, + }, + + _ => { + unreachable!() + } + } + } + + /// Build the class `ndarray` and its associated methods. + fn build_ndarray_class_related(&self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed( + prim, + &[PrimDef::NDArray, PrimDef::NDArrayCopy, PrimDef::NDArrayFill], + ); + + match prim { + PrimDef::NDArray => TopLevelDef::Class { + name: prim.name().into(), + object_id: prim.id(), + type_vars: vec![self.ndarray_dtype_tvar.0, self.ndarray_ndims_tvar.0], + fields: Vec::default(), + methods: vec![ + Self::create_method(PrimDef::NDArrayCopy, self.ndarray_copy_ty.0), + Self::create_method(PrimDef::NDArrayFill, self.ndarray_fill_ty.0), + ], + ancestors: Vec::default(), + constructor: None, + resolver: None, + loc: None, + }, + + PrimDef::NDArrayCopy => TopLevelDef::Function { + name: prim.name().into(), + simple_name: prim.simple_name().into(), + signature: self.ndarray_copy_ty.0, + var_id: vec![self.ndarray_dtype_tvar.1, self.ndarray_ndims_tvar.1], instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, codegen_callback: Some(Arc::new(GenCall::new(Box::new( |ctx, obj, fun, args, generator| { - gen_ndarray_array(ctx, &obj, fun, &args, generator) + gen_ndarray_copy(ctx, &obj, fun, &args, generator) .map(|val| Some(val.as_basic_value_enum())) }, )))), loc: None, - })) - }, - Arc::new(RwLock::new(TopLevelDef::Function { - name: "np_eye".into(), - simple_name: "np_eye".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![ - FuncArg { name: "N".into(), ty: int32, default_value: None }, - // TODO(Derppening): Default values current do not work? - FuncArg { - name: "M".into(), - ty: int32, - default_value: Some(SymbolValue::OptionNone), + }, + + PrimDef::NDArrayFill => TopLevelDef::Function { + name: prim.name().into(), + simple_name: prim.simple_name().into(), + signature: self.ndarray_fill_ty.0, + var_id: vec![self.ndarray_dtype_tvar.1, self.ndarray_ndims_tvar.1], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, obj, fun, args, generator| { + gen_ndarray_fill(ctx, &obj, fun, &args, generator)?; + Ok(None) }, - FuncArg { - name: "k".into(), - ty: int32, - default_value: Some(SymbolValue::I32(0)), - }, - ], - ret: ndarray_float_2d, - vars: VarMap::default(), + )))), + loc: None, + }, + + _ => unreachable!(), + } + } + + /// Build functions that cast a numeric primitive to another numeric primitive, including booleans. + fn build_cast_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed( + prim, + &[ + PrimDef::FunInt32, + PrimDef::FunInt64, + PrimDef::FunUInt32, + PrimDef::FunUInt64, + PrimDef::FunFloat, + PrimDef::FunBool, + ], + ); + + TopLevelDef::Function { + name: prim.name().into(), + simple_name: prim.simple_name().into(), + signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![FuncArg { + name: "n".into(), + ty: self.num_or_ndarray_ty.0, + default_value: None, + }], + ret: self.num_or_ndarray_ty.0, + vars: self.num_or_ndarray_var_map.clone(), })), var_id: Vec::default(), instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, obj, fun, args, generator| { - gen_ndarray_eye(ctx, &obj, fun, &args, generator) - .map(|val| Some(val.as_basic_value_enum())) + move |ctx, _, fun, args, generator| { + let arg_ty = fun.0.args[0].ty; + let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; + + let func = match prim { + PrimDef::FunInt32 => builtin_fns::call_int32, + PrimDef::FunInt64 => builtin_fns::call_int64, + PrimDef::FunUInt32 => builtin_fns::call_uint32, + PrimDef::FunUInt64 => builtin_fns::call_uint64, + PrimDef::FunFloat => builtin_fns::call_float, + PrimDef::FunBool => builtin_fns::call_bool, + _ => unreachable!(), + }; + Ok(Some(func(generator, ctx, (arg_ty, arg))?)) }, )))), loc: None, - })), + } + } + + /// Build the functions `round()` and `round64()`. + fn build_round_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed(prim, &[PrimDef::FunRound, PrimDef::FunRound64]); + + let float = self.primitives.float; + + let size_variant = match prim { + PrimDef::FunRound => SizeVariant::Bits32, + PrimDef::FunRound64 => SizeVariant::Bits64, + _ => unreachable!(), + }; + + let common_ndim = self.unifier.get_fresh_const_generic_var( + self.primitives.usize(), + Some("N".into()), + None, + ); + + // The size variant of the function determines the size of the returned int. + let int_sized = size_variant.of_int(self.primitives); + + let ndarray_int_sized = + make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.0)); + let ndarray_float = + make_ndarray_ty(self.unifier, self.primitives, Some(float), Some(common_ndim.0)); + + let p0_ty = + self.unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None); + let ret_ty = self.unifier.get_fresh_var_with_range( + &[int_sized, ndarray_int_sized], + Some("R".into()), + None, + ); + create_fn_by_codegen( - unifier, - &VarMap::new(), - "np_identity", - ndarray_float_2d, - &[(int32, "n")], - Box::new(|ctx, obj, fun, args, generator| { - gen_ndarray_identity(ctx, &obj, fun, &args, generator) - .map(|val| Some(val.as_basic_value_enum())) - }), - ), - { - let common_ndim = - unifier.get_fresh_const_generic_var(primitives.usize(), Some("N".into()), None); - let ndarray_int32 = - make_ndarray_ty(unifier, primitives, Some(int32), Some(common_ndim.0)); - let ndarray_float = - make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0)); - - let p0_ty = - unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None); - let ret_ty = - unifier.get_fresh_var_with_range(&[int32, ndarray_int32], Some("R".into()), None); - - create_fn_by_codegen( - unifier, - &[(common_ndim.1, common_ndim.0), (p0_ty.1, p0_ty.0), (ret_ty.1, ret_ty.0)] - .into_iter() - .collect(), - "round", - ret_ty.0, - &[(p0_ty.0, "n")], - Box::new(|ctx, _, fun, args, generator| { - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - - Ok(Some(builtin_fns::call_round( - generator, - ctx, - (arg_ty, arg), - ctx.primitives.int32, - )?)) - }), - ) - }, - { - let common_ndim = - unifier.get_fresh_const_generic_var(primitives.usize(), Some("N".into()), None); - let ndarray_int64 = - make_ndarray_ty(unifier, primitives, Some(int64), Some(common_ndim.0)); - let ndarray_float = - make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0)); - - let p0_ty = - unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None); - let ret_ty = - unifier.get_fresh_var_with_range(&[int64, ndarray_int64], Some("R".into()), None); - - create_fn_by_codegen( - unifier, - &[(common_ndim.1, common_ndim.0), (p0_ty.1, p0_ty.0), (ret_ty.1, ret_ty.0)] - .into_iter() - .collect(), - "round64", - ret_ty.0, - &[(p0_ty.0, "n")], - Box::new(|ctx, _, fun, args, generator| { - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - - Ok(Some(builtin_fns::call_round( - generator, - ctx, - (arg_ty, arg), - ctx.primitives.int64, - )?)) - }), - ) - }, - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_round", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "n")], - Box::new(|ctx, _, fun, args, generator| { + self.unifier, + &[(common_ndim.1, common_ndim.0), (p0_ty.1, p0_ty.0), (ret_ty.1, ret_ty.0)] + .into_iter() + .collect(), + prim.name(), + ret_ty.0, + &[(p0_ty.0, "n")], + Box::new(move |ctx, _, fun, args, generator| { let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(Some(builtin_fns::call_numpy_round(generator, ctx, (arg_ty, arg))?)) + let ret_elem_ty = size_variant.of_int(&ctx.primitives); + Ok(Some(builtin_fns::call_round(generator, ctx, (arg_ty, arg), ret_elem_ty)?)) }), - ), - Arc::new(RwLock::new(TopLevelDef::Function { - name: "range".into(), - simple_name: "range".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { + ) + } + + /// Build the functions `ceil()` and `floor()` and their 64 bit variants. + fn build_ceil_floor_function(&mut self, prim: PrimDef) -> TopLevelDef { + #[derive(Clone, Copy)] + enum Kind { + Floor, + Ceil, + } + + debug_assert_prim_is_allowed( + prim, + &[PrimDef::FunFloor, PrimDef::FunFloor64, PrimDef::FunCeil, PrimDef::FunCeil64], + ); + + let (size_variant, kind) = { + match prim { + PrimDef::FunFloor => (SizeVariant::Bits32, Kind::Floor), + PrimDef::FunFloor64 => (SizeVariant::Bits64, Kind::Floor), + PrimDef::FunCeil => (SizeVariant::Bits32, Kind::Ceil), + PrimDef::FunCeil64 => (SizeVariant::Bits64, Kind::Ceil), + _ => unreachable!(), + } + }; + + let float = self.primitives.float; + + let common_ndim = self.unifier.get_fresh_const_generic_var( + self.primitives.usize(), + Some("N".into()), + None, + ); + + let ndarray_float = + make_ndarray_ty(self.unifier, self.primitives, Some(float), Some(common_ndim.0)); + + // The size variant of the function determines the type of int returned + let int_sized = size_variant.of_int(self.primitives); + let ndarray_int_sized = + make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.0)); + + let p0_ty = + self.unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None); + + let ret_ty = self.unifier.get_fresh_var_with_range( + &[int_sized, ndarray_int_sized], + Some("R".into()), + None, + ); + + create_fn_by_codegen( + self.unifier, + &[(common_ndim.1, common_ndim.0), (p0_ty.1, p0_ty.0), (ret_ty.1, ret_ty.0)] + .into_iter() + .collect(), + prim.name(), + ret_ty.0, + &[(p0_ty.0, "n")], + Box::new(move |ctx, _, fun, args, generator| { + let arg_ty = fun.0.args[0].ty; + let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; + + let ret_elem_ty = size_variant.of_int(&ctx.primitives); + let func = match kind { + Kind::Ceil => builtin_fns::call_ceil, + Kind::Floor => builtin_fns::call_floor, + }; + Ok(Some(func(generator, ctx, (arg_ty, arg), ret_elem_ty)?)) + }), + ) + } + + /// Build ndarray factory functions that only take in an argument `shape` of type `list[int32]` and return an ndarray. + fn build_ndarray_from_shape_factory_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed( + prim, + &[PrimDef::FunNpNDArray, PrimDef::FunNpEmpty, PrimDef::FunNpZeros, PrimDef::FunNpOnes], + ); + + create_fn_by_codegen( + self.unifier, + &VarMap::new(), + prim.name(), + self.ndarray_float, + // We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a + // type variable + &[(self.list_int32, "shape")], + Box::new(move |ctx, obj, fun, args, generator| { + let func = match prim { + PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => gen_ndarray_empty, + PrimDef::FunNpZeros => gen_ndarray_zeros, + PrimDef::FunNpOnes => gen_ndarray_ones, + _ => unreachable!(), + }; + func(ctx, &obj, fun, &args, generator).map(|val| Some(val.as_basic_value_enum())) + }), + ) + } + + /// Build ndarray factory functions that do not fit in any other `build_ndarray_*_factory_function` categories in [`BuiltinBuilder`]. + /// + /// See also [`BuiltinBuilder::build_ndarray_from_shape_factory_function`]. + fn build_ndarray_other_factory_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed( + prim, + &[PrimDef::FunNpArray, PrimDef::FunNpFull, PrimDef::FunNpEye, PrimDef::FunNpIdentity], + ); + + let PrimitiveStore { int32, bool, ndarray, .. } = *self.primitives; + + match prim { + PrimDef::FunNpArray => { + let tv = self.unifier.get_fresh_var(Some("T".into()), None); + + TopLevelDef::Function { + name: prim.name().into(), + simple_name: prim.simple_name().into(), + signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![ + FuncArg { name: "object".into(), ty: tv.0, default_value: None }, + FuncArg { + name: "copy".into(), + ty: bool, + default_value: Some(SymbolValue::Bool(true)), + }, + FuncArg { + name: "ndmin".into(), + ty: int32, + default_value: Some(SymbolValue::U32(0)), + }, + ], + ret: ndarray, + vars: VarMap::from([(tv.1, tv.0)]), + })), + var_id: vec![tv.1], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, obj, fun, args, generator| { + gen_ndarray_array(ctx, &obj, fun, &args, generator) + .map(|val| Some(val.as_basic_value_enum())) + }, + )))), + loc: None, + } + } + PrimDef::FunNpFull => { + let tv = self.unifier.get_fresh_var(Some("T".into()), None); + + create_fn_by_codegen( + self.unifier, + &[(tv.1, tv.0)].into_iter().collect(), + prim.name(), + self.primitives.ndarray, + // We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a + // type variable + &[(self.list_int32, "shape"), (tv.0, "fill_value")], + Box::new(move |ctx, obj, fun, args, generator| { + gen_ndarray_full(ctx, &obj, fun, &args, generator) + .map(|val| Some(val.as_basic_value_enum())) + }), + ) + } + PrimDef::FunNpEye => { + TopLevelDef::Function { + name: prim.name().into(), + simple_name: prim.simple_name().into(), + signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![ + FuncArg { name: "N".into(), ty: int32, default_value: None }, + // TODO(Derppening): Default values current do not work? + FuncArg { + name: "M".into(), + ty: int32, + default_value: Some(SymbolValue::OptionNone), + }, + FuncArg { + name: "k".into(), + ty: int32, + default_value: Some(SymbolValue::I32(0)), + }, + ], + ret: self.ndarray_float_2d, + vars: VarMap::default(), + })), + var_id: Vec::default(), + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, obj, fun, args, generator| { + gen_ndarray_eye(ctx, &obj, fun, &args, generator) + .map(|val| Some(val.as_basic_value_enum())) + }, + )))), + loc: None, + } + } + PrimDef::FunNpIdentity => create_fn_by_codegen( + self.unifier, + &VarMap::new(), + prim.name(), + self.ndarray_float_2d, + &[(int32, "n")], + Box::new(|ctx, obj, fun, args, generator| { + gen_ndarray_identity(ctx, &obj, fun, &args, generator) + .map(|val| Some(val.as_basic_value_enum())) + }), + ), + _ => unreachable!(), + } + } + + /// Build the `range()` function. + fn build_range_function(&mut self) -> TopLevelDef { + let prim = PrimDef::FunRange; + + let PrimitiveStore { int32, range, .. } = *self.primitives; + + TopLevelDef::Function { + name: prim.name().into(), + simple_name: prim.simple_name().into(), + signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![ FuncArg { name: "start".into(), ty: int32, default_value: None }, FuncArg { @@ -1037,13 +1248,21 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built }, )))), loc: None, - })), - Arc::new(RwLock::new(TopLevelDef::Function { - name: "str".into(), - simple_name: "str".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { name: "s".into(), ty: string, default_value: None }], - ret: string, + } + } + + /// Build the `str()` function. + fn build_str_function(&mut self) -> TopLevelDef { + let prim = PrimDef::FunStr; + + let str = self.primitives.str; + + TopLevelDef::Function { + name: prim.name().into(), + simple_name: prim.simple_name().into(), + signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![FuncArg { name: "s".into(), ty: str, default_value: None }], + ret: str, vars: VarMap::default(), })), var_id: Vec::default(), @@ -1057,517 +1276,299 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built }, )))), loc: None, - })), - Arc::new(RwLock::new(TopLevelDef::Function { - name: "bool".into(), - simple_name: "bool".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { - name: "n".into(), - ty: num_or_ndarray_ty.0, - default_value: None, - }], - ret: num_or_ndarray_ty.0, - vars: num_or_ndarray_var_map.clone(), + } + } + + /// Build functions `np_ceil()` and `np_floor()`. + fn build_np_ceil_floor_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpCeil, PrimDef::FunNpFloor]); + + create_fn_by_codegen( + self.unifier, + &self.float_or_ndarray_var_map, + prim.name(), + self.float_or_ndarray_ty.0, + &[(self.float_or_ndarray_ty.0, "n")], + Box::new(move |ctx, _, fun, args, generator| { + let arg_ty = fun.0.args[0].ty; + let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; + + let func = match prim { + PrimDef::FunNpCeil => builtin_fns::call_ceil, + PrimDef::FunNpFloor => builtin_fns::call_floor, + _ => unreachable!(), + }; + Ok(Some(func(generator, ctx, (arg_ty, arg), ctx.primitives.float)?)) + }), + ) + } + + /// Build the `np_round()` function. + fn build_np_round_function(&mut self) -> TopLevelDef { + let prim = PrimDef::FunNpRound; + + create_fn_by_codegen( + self.unifier, + &self.float_or_ndarray_var_map, + prim.name(), + self.float_or_ndarray_ty.0, + &[(self.float_or_ndarray_ty.0, "n")], + Box::new(|ctx, _, fun, args, generator| { + let arg_ty = fun.0.args[0].ty; + let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; + Ok(Some(builtin_fns::call_numpy_round(generator, ctx, (arg_ty, arg))?)) + }), + ) + } + + /// Build the `len()` function. + fn build_len_function(&mut self) -> TopLevelDef { + let prim = PrimDef::FunLen; + + let PrimitiveStore { uint64, int32, .. } = *self.primitives; + + let tvar = self.unifier.get_fresh_var(Some("L".into()), None); + let list = self.unifier.add_ty(TypeEnum::TList { ty: tvar.0 }); + let ndims = self.unifier.get_fresh_const_generic_var(uint64, Some("N".into()), None); + let ndarray = make_ndarray_ty(self.unifier, self.primitives, Some(tvar.0), Some(ndims.0)); + + let arg_ty = self.unifier.get_fresh_var_with_range( + &[list, ndarray, self.primitives.range], + Some("I".into()), + None, + ); + TopLevelDef::Function { + name: prim.name().into(), + simple_name: prim.simple_name().into(), + signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![FuncArg { name: "ls".into(), ty: arg_ty.0, default_value: None }], + ret: int32, + vars: vec![(tvar.1, tvar.0), (arg_ty.1, arg_ty.0)].into_iter().collect(), })), var_id: Vec::default(), instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, fun, args, generator| { + move |ctx, _, fun, args, generator| { + let range_ty = ctx.primitives.range; let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - - Ok(Some(builtin_fns::call_bool(generator, ctx, (arg_ty, arg))?)) - }, - )))), - loc: None, - })), - { - let common_ndim = - unifier.get_fresh_const_generic_var(primitives.usize(), Some("N".into()), None); - let ndarray_int32 = - make_ndarray_ty(unifier, primitives, Some(int32), Some(common_ndim.0)); - let ndarray_float = - make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0)); - - let p0_ty = - unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None); - let ret_ty = - unifier.get_fresh_var_with_range(&[int32, ndarray_int32], Some("R".into()), None); - - create_fn_by_codegen( - unifier, - &[(common_ndim.1, common_ndim.0), (p0_ty.1, p0_ty.0), (ret_ty.1, ret_ty.0)] - .into_iter() - .collect(), - "floor", - ret_ty.0, - &[(p0_ty.0, "n")], - Box::new(|ctx, _, fun, args, generator| { - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - - Ok(Some(builtin_fns::call_floor( - generator, - ctx, - (arg_ty, arg), - ctx.primitives.int32, - )?)) - }), - ) - }, - { - let common_ndim = - unifier.get_fresh_const_generic_var(primitives.usize(), Some("N".into()), None); - let ndarray_int64 = - make_ndarray_ty(unifier, primitives, Some(int64), Some(common_ndim.0)); - let ndarray_float = - make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0)); - - let p0_ty = - unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None); - let ret_ty = - unifier.get_fresh_var_with_range(&[int64, ndarray_int64], Some("R".into()), None); - - create_fn_by_codegen( - unifier, - &[(common_ndim.1, common_ndim.0), (p0_ty.1, p0_ty.0), (ret_ty.1, ret_ty.0)] - .into_iter() - .collect(), - "floor64", - ret_ty.0, - &[(p0_ty.0, "n")], - Box::new(|ctx, _, fun, args, generator| { - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - - Ok(Some(builtin_fns::call_floor( - generator, - ctx, - (arg_ty, arg), - ctx.primitives.int64, - )?)) - }), - ) - }, - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_floor", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "n")], - Box::new(|ctx, _, fun, args, generator| { - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - - Ok(Some(builtin_fns::call_floor( - generator, - ctx, - (arg_ty, arg), - ctx.primitives.float, - )?)) - }), - ), - { - let common_ndim = - unifier.get_fresh_const_generic_var(primitives.usize(), Some("N".into()), None); - let ndarray_int32 = - make_ndarray_ty(unifier, primitives, Some(int32), Some(common_ndim.0)); - let ndarray_float = - make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0)); - - let p0_ty = - unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None); - let ret_ty = - unifier.get_fresh_var_with_range(&[int32, ndarray_int32], Some("R".into()), None); - - create_fn_by_codegen( - unifier, - &[(common_ndim.1, common_ndim.0), (p0_ty.1, p0_ty.0), (ret_ty.1, ret_ty.0)] - .into_iter() - .collect(), - "ceil", - ret_ty.0, - &[(p0_ty.0, "n")], - Box::new(|ctx, _, fun, args, generator| { - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - - Ok(Some(builtin_fns::call_ceil( - generator, - ctx, - (arg_ty, arg), - ctx.primitives.int32, - )?)) - }), - ) - }, - { - let common_ndim = - unifier.get_fresh_const_generic_var(primitives.usize(), Some("N".into()), None); - let ndarray_int64 = - make_ndarray_ty(unifier, primitives, Some(int64), Some(common_ndim.0)); - let ndarray_float = - make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0)); - - let p0_ty = - unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None); - let ret_ty = - unifier.get_fresh_var_with_range(&[int64, ndarray_int64], Some("R".into()), None); - - create_fn_by_codegen( - unifier, - &[(common_ndim.1, common_ndim.0), (p0_ty.1, p0_ty.0), (ret_ty.1, ret_ty.0)] - .into_iter() - .collect(), - "ceil64", - ret_ty.0, - &[(p0_ty.0, "n")], - Box::new(|ctx, _, fun, args, generator| { - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - - Ok(Some(builtin_fns::call_ceil( - generator, - ctx, - (arg_ty, arg), - ctx.primitives.int64, - )?)) - }), - ) - }, - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_ceil", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "n")], - Box::new(|ctx, _, fun, args, generator| { - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - - Ok(Some(builtin_fns::call_ceil( - generator, - ctx, - (arg_ty, arg), - ctx.primitives.float, - )?)) - }), - ), - Arc::new(RwLock::new({ - let tvar = unifier.get_fresh_var(Some("L".into()), None); - let list = unifier.add_ty(TypeEnum::TList { ty: tvar.0 }); - let ndims = - unifier.get_fresh_const_generic_var(primitives.uint64, Some("N".into()), None); - let ndarray = make_ndarray_ty(unifier, primitives, Some(tvar.0), Some(ndims.0)); - - let arg_ty = unifier.get_fresh_var_with_range( - &[list, ndarray, primitives.range], - Some("I".into()), - None, - ); - TopLevelDef::Function { - name: "len".into(), - simple_name: "len".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { name: "ls".into(), ty: arg_ty.0, default_value: None }], - ret: int32, - vars: vec![(tvar.1, tvar.0), (arg_ty.1, arg_ty.0)].into_iter().collect(), - })), - var_id: Vec::default(), - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, fun, args, generator| { - let range_ty = ctx.primitives.range; - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(if ctx.unifier.unioned(arg_ty, range_ty) { - let arg = - RangeValue::from_ptr_val(arg.into_pointer_value(), Some("range")); - let (start, end, step) = destructure_range(ctx, arg); - Some( - calculate_len_for_slice_range(generator, ctx, start, end, step) - .into(), - ) - } else { - match &*ctx.unifier.get_ty_immutable(arg_ty) { - TypeEnum::TList { .. } => { - let int32 = ctx.ctx.i32_type(); - let zero = int32.const_zero(); - let len = ctx - .build_gep_and_load( - arg.into_pointer_value(), - &[zero, int32.const_int(1, false)], - None, - ) - .into_int_value(); - if len.get_type().get_bit_width() == 32 { - Some(len.into()) - } else { - Some( - ctx.builder - .build_int_truncate(len, int32, "len2i32") - .map(Into::into) - .unwrap(), - ) - } - } - TypeEnum::TObj { obj_id, .. } - if *obj_id == PrimDef::NDArray.id() => - { - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - let arg = NDArrayValue::from_ptr_val( + Ok(if ctx.unifier.unioned(arg_ty, range_ty) { + let arg = RangeValue::from_ptr_val(arg.into_pointer_value(), Some("range")); + let (start, end, step) = destructure_range(ctx, arg); + Some(calculate_len_for_slice_range(generator, ctx, start, end, step).into()) + } else { + match &*ctx.unifier.get_ty_immutable(arg_ty) { + TypeEnum::TList { .. } => { + let int32 = ctx.ctx.i32_type(); + let zero = int32.const_zero(); + let len = ctx + .build_gep_and_load( arg.into_pointer_value(), - llvm_usize, + &[zero, int32.const_int(1, false)], None, - ); - - let ndims = arg.dim_sizes().size(ctx, generator); - ctx.make_assert( - generator, + ) + .into_int_value(); + if len.get_type().get_bit_width() == 32 { + Some(len.into()) + } else { + Some( ctx.builder - .build_int_compare( - IntPredicate::NE, - ndims, - llvm_usize.const_zero(), - "", - ) + .build_int_truncate(len, int32, "len2i32") + .map(Into::into) .unwrap(), - "0:TypeError", - "len() of unsized object", - [None, None, None], - ctx.current_loc, - ); - - let len = unsafe { - arg.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_zero(), - None, - ) - }; - - if len.get_type().get_bit_width() == 32 { - Some(len.into()) - } else { - Some( - ctx.builder - .build_int_truncate(len, llvm_i32, "len") - .map(Into::into) - .unwrap(), - ) - } + ) } - _ => unreachable!(), } - }) - }, - )))), - loc: None, - } - })), - Arc::new(RwLock::new(TopLevelDef::Function { - name: "min".into(), - simple_name: "min".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { + TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + + let arg = NDArrayValue::from_ptr_val( + arg.into_pointer_value(), + llvm_usize, + None, + ); + + let ndims = arg.dim_sizes().size(ctx, generator); + ctx.make_assert( + generator, + ctx.builder + .build_int_compare( + IntPredicate::NE, + ndims, + llvm_usize.const_zero(), + "", + ) + .unwrap(), + "0:TypeError", + &format!("{name}() of unsized object", name = prim.name()), + [None, None, None], + ctx.current_loc, + ); + + let len = unsafe { + arg.dim_sizes().get_typed_unchecked( + ctx, + generator, + &llvm_usize.const_zero(), + None, + ) + }; + + if len.get_type().get_bit_width() == 32 { + Some(len.into()) + } else { + Some( + ctx.builder + .build_int_truncate(len, llvm_i32, "len") + .map(Into::into) + .unwrap(), + ) + } + } + _ => unreachable!(), + } + }) + }, + )))), + loc: None, + } + } + + /// Build the functions `min()` and `max()`. + fn build_min_max_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed(prim, &[PrimDef::FunMin, PrimDef::FunMax]); + + TopLevelDef::Function { + name: prim.name().into(), + simple_name: prim.simple_name().into(), + signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![ - FuncArg { name: "m".into(), ty: num_ty.0, default_value: None }, - FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }, + FuncArg { name: "m".into(), ty: self.num_ty.0, default_value: None }, + FuncArg { name: "n".into(), ty: self.num_ty.0, default_value: None }, ], - ret: num_ty.0, - vars: num_var_map.clone(), + ret: self.num_ty.0, + vars: self.num_var_map.clone(), })), var_id: Vec::default(), instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, fun, args, generator| { + move |ctx, _, fun, args, generator| { let m_ty = fun.0.args[0].ty; let n_ty = fun.0.args[1].ty; let m_val = args[0].1.clone().to_basic_value_enum(ctx, generator, m_ty)?; let n_val = args[1].1.clone().to_basic_value_enum(ctx, generator, n_ty)?; - Ok(Some(builtin_fns::call_min(ctx, (m_ty, m_val), (n_ty, n_val)))) + let func = match prim { + PrimDef::FunMin => builtin_fns::call_min, + PrimDef::FunMax => builtin_fns::call_max, + _ => unreachable!(), + }; + Ok(Some(func(ctx, (m_ty, m_val), (n_ty, n_val)))) }, )))), loc: None, - })), - { - let ret_ty = unifier.get_fresh_var(Some("R".into()), None); - let var_map = num_or_ndarray_var_map - .clone() - .into_iter() - .chain(once((ret_ty.1, ret_ty.0))) - .collect::>(); + } + } - create_fn_by_codegen( - unifier, - &var_map, - "np_min", - ret_ty.0, - &[(float_or_ndarray_ty.0, "a")], - Box::new(|ctx, _, fun, args, generator| { - let a_ty = fun.0.args[0].ty; - let a = args[0].1.clone().to_basic_value_enum(ctx, generator, a_ty)?; + /// Build the functions `np_min()` and `np_max()`. + fn build_np_min_max_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpMin, PrimDef::FunNpMax]); - Ok(Some(builtin_fns::call_numpy_min(generator, ctx, (a_ty, a))?)) - }), - ) - }, - { - let x1_ty = new_type_or_ndarray_ty(unifier, primitives, num_ty.0); - let x2_ty = new_type_or_ndarray_ty(unifier, primitives, num_ty.0); - let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")]; - let ret_ty = unifier.get_fresh_var(None, None); + let ret_ty = self.unifier.get_fresh_var(Some("R".into()), None); + let var_map = self + .num_or_ndarray_var_map + .clone() + .into_iter() + .chain(once((ret_ty.1, ret_ty.0))) + .collect::>(); - Arc::new(RwLock::new(TopLevelDef::Function { - name: "np_minimum".into(), - simple_name: "np_minimum".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: param_ty - .iter() - .map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None }) - .collect(), - ret: ret_ty.0, - vars: [(x1_ty.1, x1_ty.0), (x2_ty.1, x2_ty.0), (ret_ty.1, ret_ty.0)] - .into_iter() - .collect(), - })), - var_id: vec![x1_ty.1, x2_ty.1], - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x2_ty = fun.0.args[1].ty; - let x1_val = - args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; - let x2_val = - args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; + create_fn_by_codegen( + self.unifier, + &var_map, + prim.name(), + ret_ty.0, + &[(self.float_or_ndarray_ty.0, "a")], + Box::new(move |ctx, _, fun, args, generator| { + let a_ty = fun.0.args[0].ty; + let a = args[0].1.clone().to_basic_value_enum(ctx, generator, a_ty)?; - Ok(Some(builtin_fns::call_numpy_minimum( - generator, - ctx, - (x1_ty, x1_val), - (x2_ty, x2_val), - )?)) - }, - )))), - loc: None, - })) - }, - Arc::new(RwLock::new(TopLevelDef::Function { - name: "max".into(), - simple_name: "max".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![ - FuncArg { name: "m".into(), ty: num_ty.0, default_value: None }, - FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }, - ], - ret: num_ty.0, - vars: num_var_map.clone(), + let func = match prim { + PrimDef::FunNpMin => builtin_fns::call_numpy_min, + PrimDef::FunNpMax => builtin_fns::call_numpy_max, + _ => unreachable!(), + }; + + Ok(Some(func(generator, ctx, (a_ty, a))?)) + }), + ) + } + + /// Build the functions `np_minimum()` and `np_maximum()`. + fn build_np_minimum_maximum_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpMinimum, PrimDef::FunNpMaximum]); + + let x1_ty = self.new_type_or_ndarray_ty(self.num_ty.0); + let x2_ty = self.new_type_or_ndarray_ty(self.num_ty.0); + let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")]; + let ret_ty = self.unifier.get_fresh_var(None, None); + + TopLevelDef::Function { + name: prim.name().into(), + simple_name: prim.simple_name().into(), + signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: param_ty + .iter() + .map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None }) + .collect(), + ret: ret_ty.0, + vars: [(x1_ty.1, x1_ty.0), (x2_ty.1, x2_ty.0), (ret_ty.1, ret_ty.0)] + .into_iter() + .collect(), })), - var_id: Vec::default(), + var_id: vec![x1_ty.1, x2_ty.1], instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, fun, args, generator| { - let m_ty = fun.0.args[0].ty; - let n_ty = fun.0.args[1].ty; - let m_val = args[0].1.clone().to_basic_value_enum(ctx, generator, m_ty)?; - let n_val = args[1].1.clone().to_basic_value_enum(ctx, generator, n_ty)?; + move |ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x2_ty = fun.0.args[1].ty; + let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - Ok(Some(builtin_fns::call_max(ctx, (m_ty, m_val), (n_ty, n_val)))) + let func = match prim { + PrimDef::FunNpMinimum => builtin_fns::call_numpy_minimum, + PrimDef::FunNpMaximum => builtin_fns::call_numpy_maximum, + _ => unreachable!(), + }; + + Ok(Some(func(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) }, )))), loc: None, - })), - { - let ret_ty = unifier.get_fresh_var(Some("R".into()), None); - let var_map = num_or_ndarray_var_map - .clone() - .into_iter() - .chain(once((ret_ty.1, ret_ty.0))) - .collect::>(); + } + } - create_fn_by_codegen( - unifier, - &var_map, - "np_max", - ret_ty.0, - &[(float_or_ndarray_ty.0, "a")], - Box::new(|ctx, _, fun, args, generator| { - let a_ty = fun.0.args[0].ty; - let a = args[0].1.clone().to_basic_value_enum(ctx, generator, a_ty)?; + /// Build the `abs()` function. + fn build_abs_function(&mut self) -> TopLevelDef { + let prim = PrimDef::FunAbs; - Ok(Some(builtin_fns::call_numpy_max(generator, ctx, (a_ty, a))?)) - }), - ) - }, - { - let x1_ty = new_type_or_ndarray_ty(unifier, primitives, num_ty.0); - let x2_ty = new_type_or_ndarray_ty(unifier, primitives, num_ty.0); - let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")]; - let ret_ty = unifier.get_fresh_var(None, None); - - Arc::new(RwLock::new(TopLevelDef::Function { - name: "np_maximum".into(), - simple_name: "np_maximum".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: param_ty - .iter() - .map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None }) - .collect(), - ret: ret_ty.0, - vars: [(x1_ty.1, x1_ty.0), (x2_ty.1, x2_ty.0), (ret_ty.1, ret_ty.0)] - .into_iter() - .collect(), - })), - var_id: vec![x1_ty.1, x2_ty.1], - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x2_ty = fun.0.args[1].ty; - let x1_val = - args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; - let x2_val = - args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - - Ok(Some(builtin_fns::call_numpy_maximum( - generator, - ctx, - (x1_ty, x1_val), - (x2_ty, x2_val), - )?)) - }, - )))), - loc: None, - })) - }, - Arc::new(RwLock::new(TopLevelDef::Function { - name: "abs".into(), - simple_name: "abs".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { + TopLevelDef::Function { + name: prim.name().into(), + simple_name: prim.simple_name().into(), + signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![FuncArg { name: "n".into(), - ty: num_or_ndarray_ty.0, + ty: self.num_or_ndarray_ty.0, default_value: None, }], - ret: num_or_ndarray_ty.0, - vars: num_or_ndarray_var_map.clone(), + ret: self.num_or_ndarray_ty.0, + vars: self.num_or_ndarray_var_map.clone(), })), var_id: Vec::default(), instance_to_symbol: HashMap::default(), @@ -1582,727 +1583,220 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built }, )))), loc: None, - })), + } + } + + /// Build numpy functions that take in a float and return a boolean. + fn build_np_float_to_bool_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpIsInf, PrimDef::FunNpIsNan]); + + let PrimitiveStore { bool, float, .. } = *self.primitives; + create_fn_by_codegen( - unifier, + self.unifier, &VarMap::new(), - "np_isnan", - boolean, + prim.name(), + bool, &[(float, "x")], - Box::new(|ctx, _, fun, args, generator| { + Box::new(move |ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; - Ok(Some(builtin_fns::call_numpy_isnan(generator, ctx, (x_ty, x_val))?)) + let func = match prim { + PrimDef::FunNpIsInf => builtin_fns::call_numpy_isinf, + PrimDef::FunNpIsNan => builtin_fns::call_numpy_isnan, + _ => unreachable!(), + }; + + Ok(Some(func(generator, ctx, (x_ty, x_val))?)) }), - ), + ) + } + + /// Build 1-ary numpy/scipy functions that take in a float or an ndarray and return a value of the same type as the input. + fn build_np_sp_float_or_ndarray_1ary_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed( + prim, + &[ + PrimDef::FunNpSin, + PrimDef::FunNpCos, + PrimDef::FunNpTan, + PrimDef::FunNpArcsin, + PrimDef::FunNpArccos, + PrimDef::FunNpArctan, + PrimDef::FunNpSinh, + PrimDef::FunNpCosh, + PrimDef::FunNpTanh, + PrimDef::FunNpArcsinh, + PrimDef::FunNpArccosh, + PrimDef::FunNpArctanh, + PrimDef::FunNpExp, + PrimDef::FunNpExp2, + PrimDef::FunNpExpm1, + PrimDef::FunNpLog, + PrimDef::FunNpLog2, + PrimDef::FunNpLog10, + PrimDef::FunNpSqrt, + PrimDef::FunNpCbrt, + PrimDef::FunNpFabs, + PrimDef::FunNpRint, + PrimDef::FunSpSpecErf, + PrimDef::FunSpSpecErfc, + PrimDef::FunSpSpecGamma, + PrimDef::FunSpSpecGammaln, + PrimDef::FunSpSpecJ0, + PrimDef::FunSpSpecJ1, + ], + ); + + // The parameter name of the sole input of this function. + // Usually this is just "x", but some functions have a different parameter name. + let arg_name = match prim { + PrimDef::FunSpSpecErf => "z", + _ => "x", + }; + create_fn_by_codegen( - unifier, - &VarMap::new(), - "np_isinf", - boolean, - &[(float, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; + self.unifier, + &self.float_or_ndarray_var_map, + prim.name(), + self.float_or_ndarray_ty.0, + &[(self.float_or_ndarray_ty.0, arg_name)], + Box::new(move |ctx, _, fun, args, generator| { + let arg_ty = fun.0.args[0].ty; + let arg_val = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(Some(builtin_fns::call_numpy_isinf(generator, ctx, (x_ty, x_val))?)) + let func = match prim { + PrimDef::FunNpSin => builtin_fns::call_numpy_sin, + PrimDef::FunNpCos => builtin_fns::call_numpy_cos, + PrimDef::FunNpTan => builtin_fns::call_numpy_tan, + + PrimDef::FunNpArcsin => builtin_fns::call_numpy_arcsin, + PrimDef::FunNpArccos => builtin_fns::call_numpy_arccos, + PrimDef::FunNpArctan => builtin_fns::call_numpy_arctan, + + PrimDef::FunNpSinh => builtin_fns::call_numpy_sinh, + PrimDef::FunNpCosh => builtin_fns::call_numpy_cosh, + PrimDef::FunNpTanh => builtin_fns::call_numpy_tanh, + + PrimDef::FunNpArcsinh => builtin_fns::call_numpy_arcsinh, + PrimDef::FunNpArccosh => builtin_fns::call_numpy_arccosh, + PrimDef::FunNpArctanh => builtin_fns::call_numpy_arctanh, + + PrimDef::FunNpExp => builtin_fns::call_numpy_exp, + PrimDef::FunNpExp2 => builtin_fns::call_numpy_exp2, + PrimDef::FunNpExpm1 => builtin_fns::call_numpy_expm1, + + PrimDef::FunNpLog => builtin_fns::call_numpy_log, + PrimDef::FunNpLog2 => builtin_fns::call_numpy_log2, + PrimDef::FunNpLog10 => builtin_fns::call_numpy_log10, + + PrimDef::FunNpSqrt => builtin_fns::call_numpy_sqrt, + PrimDef::FunNpCbrt => builtin_fns::call_numpy_cbrt, + + PrimDef::FunNpFabs => builtin_fns::call_numpy_fabs, + PrimDef::FunNpRint => builtin_fns::call_numpy_rint, + + PrimDef::FunSpSpecErf => builtin_fns::call_scipy_special_erf, + PrimDef::FunSpSpecErfc => builtin_fns::call_scipy_special_erfc, + + PrimDef::FunSpSpecGamma => builtin_fns::call_scipy_special_gamma, + PrimDef::FunSpSpecGammaln => builtin_fns::call_scipy_special_gammaln, + + PrimDef::FunSpSpecJ0 => builtin_fns::call_scipy_special_j0, + PrimDef::FunSpSpecJ1 => builtin_fns::call_scipy_special_j1, + + _ => unreachable!(), + }; + Ok(Some(func(generator, ctx, (arg_ty, arg_val))?)) }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_sin", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; + ) + } - Ok(Some(builtin_fns::call_numpy_sin(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_cos", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; + /// Build 2-ary numpy functions. The exact argument types of the two input arguments can be controlled. + fn build_np_2ary_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed( + prim, + &[ + PrimDef::FunNpArctan2, + PrimDef::FunNpCopysign, + PrimDef::FunNpFmax, + PrimDef::FunNpFmin, + PrimDef::FunNpLdExp, + PrimDef::FunNpHypot, + PrimDef::FunNpNextAfter, + ], + ); - Ok(Some(builtin_fns::call_numpy_cos(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_exp", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; + let PrimitiveStore { float, int32, .. } = *self.primitives; - Ok(Some(builtin_fns::call_numpy_exp(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_exp2", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; + // The argument types of the two input arguments are controlled here. + let (x1_ty, x2_ty) = match prim { + PrimDef::FunNpArctan2 + | PrimDef::FunNpCopysign + | PrimDef::FunNpFmax + | PrimDef::FunNpFmin + | PrimDef::FunNpHypot + | PrimDef::FunNpNextAfter => (float, float), + PrimDef::FunNpLdExp => (float, int32), + _ => unreachable!(), + }; - Ok(Some(builtin_fns::call_numpy_exp2(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_log", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; + let x1_ty = self.new_type_or_ndarray_ty(x1_ty); + let x2_ty = self.new_type_or_ndarray_ty(x2_ty); - Ok(Some(builtin_fns::call_numpy_log(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_log10", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; + let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")]; + let ret_ty = self.unifier.get_fresh_var(None, None); - Ok(Some(builtin_fns::call_numpy_log10(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_log2", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; - - Ok(Some(builtin_fns::call_numpy_log2(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_fabs", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; - - Ok(Some(builtin_fns::call_numpy_fabs(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_sqrt", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; - - Ok(Some(builtin_fns::call_numpy_sqrt(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_rint", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; - - Ok(Some(builtin_fns::call_numpy_rint(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_tan", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; - - Ok(Some(builtin_fns::call_numpy_tan(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_arcsin", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; - - Ok(Some(builtin_fns::call_numpy_arcsin(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_arccos", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; - - Ok(Some(builtin_fns::call_numpy_arccos(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_arctan", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; - - Ok(Some(builtin_fns::call_numpy_arctan(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_sinh", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; - - Ok(Some(builtin_fns::call_numpy_sinh(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_cosh", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; - - Ok(Some(builtin_fns::call_numpy_cosh(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_tanh", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; - - Ok(Some(builtin_fns::call_numpy_tanh(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_arcsinh", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; - - Ok(Some(builtin_fns::call_numpy_arcsinh(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_arccosh", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; - - Ok(Some(builtin_fns::call_numpy_arccosh(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_arctanh", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; - - Ok(Some(builtin_fns::call_numpy_arctanh(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_expm1", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; - - Ok(Some(builtin_fns::call_numpy_expm1(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_cbrt", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; - - Ok(Some(builtin_fns::call_numpy_cbrt(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "sp_spec_erf", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "z")], - Box::new(|ctx, _, fun, args, generator| { - let z_ty = fun.0.args[0].ty; - let z_val = args[0].1.clone().to_basic_value_enum(ctx, generator, z_ty)?; - - Ok(Some(builtin_fns::call_scipy_special_erf(generator, ctx, (z_ty, z_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "sp_spec_erfc", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let z_ty = fun.0.args[0].ty; - let z_val = args[0].1.clone().to_basic_value_enum(ctx, generator, z_ty)?; - - Ok(Some(builtin_fns::call_scipy_special_erfc(generator, ctx, (z_ty, z_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "sp_spec_gamma", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "z")], - Box::new(|ctx, _, fun, args, generator| { - let z_ty = fun.0.args[0].ty; - let z_val = args[0].1.clone().to_basic_value_enum(ctx, generator, z_ty)?; - - Ok(Some(builtin_fns::call_scipy_special_gamma(generator, ctx, (z_ty, z_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "sp_spec_gammaln", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; - - Ok(Some(builtin_fns::call_scipy_special_gammaln(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "sp_spec_j0", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let z_ty = fun.0.args[0].ty; - let z_val = args[0].1.clone().to_basic_value_enum(ctx, generator, z_ty)?; - - Ok(Some(builtin_fns::call_scipy_special_j0(generator, ctx, (z_ty, z_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "sp_spec_j1", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; - - Ok(Some(builtin_fns::call_scipy_special_j1(generator, ctx, (x_ty, x_val))?)) - }), - ), - // Not mapped: jv/yv, libm only supports integer orders. - { - let x1_ty = new_type_or_ndarray_ty(unifier, primitives, float); - let x2_ty = new_type_or_ndarray_ty(unifier, primitives, float); - let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")]; - let ret_ty = unifier.get_fresh_var(None, None); - - Arc::new(RwLock::new(TopLevelDef::Function { - name: "np_arctan2".into(), - simple_name: "np_arctan2".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: param_ty - .iter() - .map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None }) - .collect(), - ret: ret_ty.0, - vars: [(x1_ty.1, x1_ty.0), (x2_ty.1, x2_ty.0), (ret_ty.1, ret_ty.0)] - .into_iter() - .collect(), - })), - var_id: vec![ret_ty.1], - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = - args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; - let x2_ty = fun.0.args[1].ty; - let x2_val = - args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - - Ok(Some(builtin_fns::call_numpy_arctan2( - generator, - ctx, - (x1_ty, x1_val), - (x2_ty, x2_val), - )?)) - }, - )))), - loc: None, - })) - }, - { - let x1_ty = new_type_or_ndarray_ty(unifier, primitives, float); - let x2_ty = new_type_or_ndarray_ty(unifier, primitives, float); - let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")]; - let ret_ty = unifier.get_fresh_var(None, None); - - Arc::new(RwLock::new(TopLevelDef::Function { - name: "np_copysign".into(), - simple_name: "np_copysign".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: param_ty - .iter() - .map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None }) - .collect(), - ret: ret_ty.0, - vars: [(x1_ty.1, x1_ty.0), (x2_ty.1, x2_ty.0), (ret_ty.1, ret_ty.0)] - .into_iter() - .collect(), - })), - var_id: vec![ret_ty.1], - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = - args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; - let x2_ty = fun.0.args[1].ty; - let x2_val = - args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - - Ok(Some(builtin_fns::call_numpy_copysign( - generator, - ctx, - (x1_ty, x1_val), - (x2_ty, x2_val), - )?)) - }, - )))), - loc: None, - })) - }, - { - let x1_ty = new_type_or_ndarray_ty(unifier, primitives, float); - let x2_ty = new_type_or_ndarray_ty(unifier, primitives, float); - let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")]; - let ret_ty = unifier.get_fresh_var(None, None); - - Arc::new(RwLock::new(TopLevelDef::Function { - name: "np_fmax".into(), - simple_name: "np_fmax".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: param_ty - .iter() - .map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None }) - .collect(), - ret: ret_ty.0, - vars: [(x1_ty.1, x1_ty.0), (x2_ty.1, x2_ty.0), (ret_ty.1, ret_ty.0)] - .into_iter() - .collect(), - })), - var_id: vec![x1_ty.1, x2_ty.1], - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = - args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; - let x2_ty = fun.0.args[1].ty; - let x2_val = - args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - - Ok(Some(builtin_fns::call_numpy_fmax( - generator, - ctx, - (x1_ty, x1_val), - (x2_ty, x2_val), - )?)) - }, - )))), - loc: None, - })) - }, - { - let x1_ty = new_type_or_ndarray_ty(unifier, primitives, float); - let x2_ty = new_type_or_ndarray_ty(unifier, primitives, float); - let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")]; - let ret_ty = unifier.get_fresh_var(None, None); - - Arc::new(RwLock::new(TopLevelDef::Function { - name: "np_fmin".into(), - simple_name: "np_fmin".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: param_ty - .iter() - .map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None }) - .collect(), - ret: ret_ty.0, - vars: [(x1_ty.1, x1_ty.0), (x2_ty.1, x2_ty.0), (ret_ty.1, ret_ty.0)] - .into_iter() - .collect(), - })), - var_id: vec![x1_ty.1, x2_ty.1], - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = - args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; - let x2_ty = fun.0.args[1].ty; - let x2_val = - args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - - Ok(Some(builtin_fns::call_numpy_fmin( - generator, - ctx, - (x1_ty, x1_val), - (x2_ty, x2_val), - )?)) - }, - )))), - loc: None, - })) - }, - { - let x1_ty = new_type_or_ndarray_ty(unifier, primitives, float); - let x2_ty = new_type_or_ndarray_ty(unifier, primitives, int32); - let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")]; - let ret_ty = unifier.get_fresh_var(None, None); - - Arc::new(RwLock::new(TopLevelDef::Function { - name: "np_ldexp".into(), - simple_name: "np_ldexp".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: param_ty - .iter() - .map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None }) - .collect(), - ret: ret_ty.0, - vars: [(x1_ty.1, x1_ty.0), (x2_ty.1, x2_ty.0), (ret_ty.1, ret_ty.0)] - .into_iter() - .collect(), - })), - var_id: vec![x1_ty.1, x2_ty.1], - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = - args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; - let x2_ty = fun.0.args[1].ty; - let x2_val = - args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - - Ok(Some(builtin_fns::call_numpy_ldexp( - generator, - ctx, - (x1_ty, x1_val), - (x2_ty, x2_val), - )?)) - }, - )))), - loc: None, - })) - }, - { - let x1_ty = new_type_or_ndarray_ty(unifier, primitives, float); - let x2_ty = new_type_or_ndarray_ty(unifier, primitives, float); - let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")]; - let ret_ty = unifier.get_fresh_var(None, None); - - Arc::new(RwLock::new(TopLevelDef::Function { - name: "np_hypot".into(), - simple_name: "np_hypot".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: param_ty - .iter() - .map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None }) - .collect(), - ret: ret_ty.0, - vars: [(x1_ty.1, x1_ty.0), (x2_ty.1, x2_ty.0), (ret_ty.1, ret_ty.0)] - .into_iter() - .collect(), - })), - var_id: vec![x1_ty.1, x2_ty.1], - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = - args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; - let x2_ty = fun.0.args[1].ty; - let x2_val = - args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - - Ok(Some(builtin_fns::call_numpy_hypot( - generator, - ctx, - (x1_ty, x1_val), - (x2_ty, x2_val), - )?)) - }, - )))), - loc: None, - })) - }, - { - let x1_ty = new_type_or_ndarray_ty(unifier, primitives, float); - let x2_ty = new_type_or_ndarray_ty(unifier, primitives, float); - let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")]; - let ret_ty = unifier.get_fresh_var(None, None); - - Arc::new(RwLock::new(TopLevelDef::Function { - name: "np_nextafter".into(), - simple_name: "np_nextafter".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: param_ty - .iter() - .map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None }) - .collect(), - ret: ret_ty.0, - vars: [(x1_ty.1, x1_ty.0), (x2_ty.1, x2_ty.0), (ret_ty.1, ret_ty.0)] - .into_iter() - .collect(), - })), - var_id: vec![x1_ty.1, x2_ty.1], - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = - args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; - let x2_ty = fun.0.args[1].ty; - let x2_val = - args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - - Ok(Some(builtin_fns::call_numpy_nextafter( - generator, - ctx, - (x1_ty, x1_val), - (x2_ty, x2_val), - )?)) - }, - )))), - loc: None, - })) - }, - Arc::new(RwLock::new(TopLevelDef::Function { - name: "Some".into(), - simple_name: "Some".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { name: "n".into(), ty: option_ty_var, default_value: None }], - ret: primitives.option, - vars: VarMap::from([(option_ty_var_id, option_ty_var)]), + TopLevelDef::Function { + name: prim.name().into(), + simple_name: prim.simple_name().into(), + signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: param_ty + .iter() + .map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None }) + .collect(), + ret: ret_ty.0, + vars: [(x1_ty.1, x1_ty.0), (x2_ty.1, x2_ty.0), (ret_ty.1, ret_ty.0)] + .into_iter() + .collect(), })), - var_id: vec![option_ty_var_id], + var_id: vec![ret_ty.1], instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, fun, args, generator| { - let arg_ty = fun.0.args[0].ty; - let arg_val = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - let alloca = generator - .gen_var_alloc(ctx, arg_val.get_type(), Some("alloca_some")) - .unwrap(); - ctx.builder.build_store(alloca, arg_val).unwrap(); - Ok(Some(alloca.into())) + move |ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_ty = fun.0.args[1].ty; + let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; + + let func = match prim { + PrimDef::FunNpArctan2 => builtin_fns::call_numpy_arctan2, + PrimDef::FunNpCopysign => builtin_fns::call_numpy_copysign, + PrimDef::FunNpFmax => builtin_fns::call_numpy_fmax, + PrimDef::FunNpFmin => builtin_fns::call_numpy_fmin, + PrimDef::FunNpLdExp => builtin_fns::call_numpy_ldexp, + PrimDef::FunNpHypot => builtin_fns::call_numpy_hypot, + PrimDef::FunNpNextAfter => builtin_fns::call_numpy_nextafter, + _ => unreachable!(), + }; + + Ok(Some(func(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) }, )))), loc: None, - })), - ]; + } + } - let ast_list: Vec>> = (0..top_level_def_list.len()).map(|_| None).collect(); + fn create_method(prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) { + (prim.simple_name().into(), method_ty, prim.id()) + } - izip!(top_level_def_list, ast_list).collect_vec() + fn new_type_or_ndarray_ty(&mut self, scalar_ty: Type) -> (Type, u32) { + let ndarray = make_ndarray_ty(self.unifier, self.primitives, Some(scalar_ty), None); + + self.unifier.get_fresh_var_with_range(&[scalar_ty, ndarray], Some("T".into()), None) + } } diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 17213637..b8d56898 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -102,6 +102,12 @@ pub enum PrimDef { FunSome, } +/// Associated details of a [`PrimDef`] +pub enum PrimDefDetails { + PrimFunction { name: &'static str, simple_name: &'static str }, + PrimClass { name: &'static str }, +} + impl PrimDef { /// Get the assigned [`DefinitionId`] of this [`PrimDef`]. /// @@ -117,6 +123,150 @@ impl PrimDef { pub fn contains_id(id: DefinitionId) -> bool { Self::iter().any(|prim| prim.id() == id) } + + /// Get the definition "simple name" of this [`PrimDef`]. + /// + /// If the [`PrimDef`] is a function, this corresponds to [`TopLevelDef::Function::simple_name`]. + /// + /// If the [`PrimDef`] is a class, this returns [`None`]. + #[must_use] + pub fn simple_name(&self) -> &'static str { + match self.details() { + PrimDefDetails::PrimFunction { simple_name, .. } => simple_name, + PrimDefDetails::PrimClass { .. } => { + panic!("PrimDef {self:?} has no simple_name as it is not a function.") + } + } + } + + /// Get the definition "name" of this [`PrimDef`]. + /// + /// If the [`PrimDef`] is a function, this corresponds to [`TopLevelDef::Function::name`]. + /// + /// If the [`PrimDef`] is a class, this corresponds to [`TopLevelDef::Class::name`]. + #[must_use] + pub fn name(&self) -> &'static str { + match self.details() { + PrimDefDetails::PrimFunction { name, .. } | PrimDefDetails::PrimClass { name } => name, + } + } + + /// Get the associated details of this [`PrimDef`] + #[must_use] + pub fn details(self) -> PrimDefDetails { + fn class(name: &'static str) -> PrimDefDetails { + PrimDefDetails::PrimClass { name } + } + + fn fun(name: &'static str, simple_name: Option<&'static str>) -> PrimDefDetails { + PrimDefDetails::PrimFunction { simple_name: simple_name.unwrap_or(name), name } + } + + match self { + PrimDef::Int32 => class("int32"), + PrimDef::Int64 => class("int64"), + PrimDef::Float => class("float"), + PrimDef::Bool => class("bool"), + PrimDef::None => class("none"), + PrimDef::Range => class("range"), + PrimDef::Str => class("str"), + PrimDef::Exception => class("Exception"), + PrimDef::UInt32 => class("uint32"), + PrimDef::UInt64 => class("uint64"), + PrimDef::Option => class("Option"), + PrimDef::OptionIsSome => fun("Option.is_some", Some("is_some")), + PrimDef::OptionIsNone => fun("Option.is_none", Some("is_none")), + PrimDef::OptionUnwrap => fun("Option.unwrap", Some("unwrap")), + PrimDef::NDArray => class("ndarray"), + PrimDef::NDArrayCopy => fun("ndarray.copy", Some("copy")), + PrimDef::NDArrayFill => fun("ndarray.fill", Some("fill")), + PrimDef::FunInt32 => fun("int32", None), + PrimDef::FunInt64 => fun("int64", None), + PrimDef::FunUInt32 => fun("uint32", None), + PrimDef::FunUInt64 => fun("uint64", None), + PrimDef::FunFloat => fun("float", None), + PrimDef::FunNpNDArray => fun("np_ndarray", None), + PrimDef::FunNpEmpty => fun("np_empty", None), + PrimDef::FunNpZeros => fun("np_zeros", None), + PrimDef::FunNpOnes => fun("np_ones", None), + PrimDef::FunNpFull => fun("np_full", None), + PrimDef::FunNpArray => fun("np_array", None), + PrimDef::FunNpEye => fun("np_eye", None), + PrimDef::FunNpIdentity => fun("np_identity", None), + PrimDef::FunRound => fun("round", None), + PrimDef::FunRound64 => fun("round64", None), + PrimDef::FunNpRound => fun("np_round", None), + PrimDef::FunRange => fun("range", None), + PrimDef::FunStr => fun("str", None), + PrimDef::FunBool => fun("bool", None), + PrimDef::FunFloor => fun("floor", None), + PrimDef::FunFloor64 => fun("floor64", None), + PrimDef::FunNpFloor => fun("np_floor", None), + PrimDef::FunCeil => fun("ceil", None), + PrimDef::FunCeil64 => fun("ceil64", None), + PrimDef::FunNpCeil => fun("np_ceil", None), + PrimDef::FunLen => fun("len", None), + PrimDef::FunMin => fun("min", None), + PrimDef::FunNpMin => fun("np_min", None), + PrimDef::FunNpMinimum => fun("np_minimum", None), + PrimDef::FunMax => fun("max", None), + PrimDef::FunNpMax => fun("np_max", None), + PrimDef::FunNpMaximum => fun("np_maximum", None), + PrimDef::FunAbs => fun("abs", None), + PrimDef::FunNpIsNan => fun("np_isnan", None), + PrimDef::FunNpIsInf => fun("np_isinf", None), + PrimDef::FunNpSin => fun("np_sin", None), + PrimDef::FunNpCos => fun("np_cos", None), + PrimDef::FunNpExp => fun("np_exp", None), + PrimDef::FunNpExp2 => fun("np_exp2", None), + PrimDef::FunNpLog => fun("np_log", None), + PrimDef::FunNpLog10 => fun("np_log10", None), + PrimDef::FunNpLog2 => fun("np_log2", None), + PrimDef::FunNpFabs => fun("np_fabs", None), + PrimDef::FunNpSqrt => fun("np_sqrt", None), + PrimDef::FunNpRint => fun("np_rint", None), + PrimDef::FunNpTan => fun("np_tan", None), + PrimDef::FunNpArcsin => fun("np_arcsin", None), + PrimDef::FunNpArccos => fun("np_arccos", None), + PrimDef::FunNpArctan => fun("np_arctan", None), + PrimDef::FunNpSinh => fun("np_sinh", None), + PrimDef::FunNpCosh => fun("np_cosh", None), + PrimDef::FunNpTanh => fun("np_tanh", None), + PrimDef::FunNpArcsinh => fun("np_arcsinh", None), + PrimDef::FunNpArccosh => fun("np_arccosh", None), + PrimDef::FunNpArctanh => fun("np_arctanh", None), + PrimDef::FunNpExpm1 => fun("np_expm1", None), + PrimDef::FunNpCbrt => fun("np_cbrt", None), + PrimDef::FunSpSpecErf => fun("sp_spec_erf", None), + PrimDef::FunSpSpecErfc => fun("sp_spec_erfc", None), + PrimDef::FunSpSpecGamma => fun("sp_spec_gamma", None), + PrimDef::FunSpSpecGammaln => fun("sp_spec_gammaln", None), + PrimDef::FunSpSpecJ0 => fun("sp_spec_j0", None), + PrimDef::FunSpSpecJ1 => fun("sp_spec_j1", None), + PrimDef::FunNpArctan2 => fun("np_arctan2", None), + PrimDef::FunNpCopysign => fun("np_copysign", None), + PrimDef::FunNpFmax => fun("np_fmax", None), + PrimDef::FunNpFmin => fun("np_fmin", None), + PrimDef::FunNpLdExp => fun("np_ldexp", None), + PrimDef::FunNpHypot => fun("np_hypot", None), + PrimDef::FunNpNextAfter => fun("np_nextafter", None), + PrimDef::FunSome => fun("Some", None), + } + } +} + +/// Asserts that a [`PrimDef`] is in an allowlist. +/// +/// Like `debug_assert!`, this statements of this function are only +/// enabled if `cfg!(debug_assertions)` is true. +pub fn debug_assert_prim_is_allowed(prim: PrimDef, allowlist: &[PrimDef]) { + if cfg!(debug_assertions) { + let allowed = allowlist.iter().any(|p| *p == prim); + assert!( + allowed, + "Disallowed primitive definition. Got {prim:?}, but expects it to be in {allowlist:?}" + ); + } } impl TopLevelDef {