From 3aaf21fcf9227c5ad78aa3906991711774490d14 Mon Sep 17 00:00:00 2001 From: lyken Date: Tue, 11 Jun 2024 15:42:49 +0800 Subject: [PATCH] core: rename `PrimitiveDefinition` and refactor `get_builtins()` --- nac3artiq/src/codegen.rs | 4 +- nac3artiq/src/symbol_resolver.rs | 6 +- nac3ast/src/ast_gen.rs | 2 +- nac3core/src/codegen/builtin_fns.rs | 142 +- nac3core/src/codegen/expr.rs | 20 +- nac3core/src/codegen/mod.rs | 8 +- nac3core/src/codegen/stmt.rs | 6 +- nac3core/src/toplevel/builtins.rs | 3133 +++++++---------- nac3core/src/toplevel/helper.rs | 489 ++- nac3core/src/toplevel/numpy.rs | 6 +- nac3core/src/toplevel/type_annotation.rs | 4 +- nac3core/src/typecheck/magic_methods.rs | 28 +- nac3core/src/typecheck/type_inferencer/mod.rs | 28 +- .../src/typecheck/type_inferencer/test.rs | 50 +- 14 files changed, 1762 insertions(+), 2164 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 1f8fd23..0de90c6 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -6,7 +6,7 @@ use nac3core::{ CodeGenContext, CodeGenerator, }, symbol_resolver::ValueEnum, - toplevel::{DefinitionId, GenCall, helper::PrimitiveDefinition}, + toplevel::{DefinitionId, GenCall, helper::PrimDef}, typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum, VarMap} }; @@ -670,7 +670,7 @@ pub fn attributes_writeback( vars: VarMap::default() }; let args: Vec<_> = values.into_iter().map(|(_, val)| (None, ValueEnum::Dynamic(val))).collect(); - if let Err(e) = rpc_codegen_callback_fn(ctx, None, (&fun, PrimitiveDefinition::Int32.id()), args, generator) { + if let Err(e) = rpc_codegen_callback_fn(ctx, None, (&fun, PrimDef::Int32.id()), args, generator) { return Ok(Err(e)); } Ok(Ok(())) diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 758e261..6f44a79 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -4,7 +4,7 @@ use nac3core::{ symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum}, toplevel::{ DefinitionId, - helper::PrimitiveDefinition, + helper::PrimDef, numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, TopLevelDef, }, @@ -468,7 +468,7 @@ impl InnerResolver { ))); } } - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimitiveDefinition::NDArray.id() => { + TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { if args.len() != 2 { return Ok(Err(format!( "type list needs exactly 2 type parameters, found {}", @@ -664,7 +664,7 @@ impl InnerResolver { } } } - (TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PrimitiveDefinition::NDArray.id() => { + (TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PrimDef::NDArray.id() => { let (ty, ndims) = unpack_ndarray_var_tys(unifier, extracted_ty); let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?; if len == 0 { diff --git a/nac3ast/src/ast_gen.rs b/nac3ast/src/ast_gen.rs index b5e76df..93bd76b 100644 --- a/nac3ast/src/ast_gen.rs +++ b/nac3ast/src/ast_gen.rs @@ -54,7 +54,7 @@ impl From<&str> for StrRef { } } -impl From for String{ +impl From for String { fn from(s: StrRef) -> Self { get_str_from_ref(&get_str_ref_lock(), s).to_string() } diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 409da2e..ce1faf0 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -7,7 +7,7 @@ use crate::codegen::{CodeGenContext, CodeGenerator, extern_fns, irrt, llvm_intri use crate::codegen::classes::{NDArrayValue, ProxyValue, UntypedArrayLikeAccessor}; use crate::codegen::numpy::ndarray_elementwise_unaryop_impl; use crate::codegen::stmt::gen_for_callback_incrementing; -use crate::toplevel::helper::PrimitiveDefinition; +use crate::toplevel::helper::PrimDef; use crate::toplevel::numpy::unpack_ndarray_var_tys; use crate::typecheck::typedef::Type; @@ -79,7 +79,7 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } - BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -150,7 +150,7 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } - BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -237,7 +237,7 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } - BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -318,7 +318,7 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } - BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -383,7 +383,7 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>( n.into() } - BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -429,7 +429,7 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } - BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -469,7 +469,7 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>( llvm_intrinsics::call_float_roundeven(ctx, n, None).into() } - BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -532,7 +532,7 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>( .unwrap() } - BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -588,7 +588,7 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>( } } - BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -638,7 +638,7 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>( } } - BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -732,7 +732,7 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>( a } - BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); @@ -830,11 +830,11 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>( call_min(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) } - (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id())) => { + (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())) => { let is_ndarray1 = x1_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PrimitiveDefinition::NDArray.id()); + .is_some_and(|id| id == PrimDef::NDArray.id()); let is_ndarray2 = x2_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PrimitiveDefinition::NDArray.id()); + .is_some_and(|id| id == PrimDef::NDArray.id()); let dtype = if is_ndarray1 && is_ndarray2 { let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); @@ -950,7 +950,7 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>( a } - BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); @@ -1048,11 +1048,11 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>( call_max(ctx, (x1_ty, x1.into()), (x2_ty, x2.into())) } - (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id())) => { + (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())) => { let is_ndarray1 = x1_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PrimitiveDefinition::NDArray.id()); + .is_some_and(|id| id == PrimDef::NDArray.id()); let is_ndarray2 = x2_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PrimitiveDefinition::NDArray.id()); + .is_some_and(|id| id == PrimDef::NDArray.id()); let dtype = if is_ndarray1 && is_ndarray2 { let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); @@ -1139,7 +1139,7 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>( llvm_intrinsics::call_float_fabs(ctx, n, Some(FN_NAME)).into() } - BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(n) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1179,7 +1179,7 @@ pub fn call_numpy_isnan<'ctx, G: CodeGenerator + ?Sized>( irrt::call_isnan(generator, ctx, x).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1221,7 +1221,7 @@ pub fn call_numpy_isinf<'ctx, G: CodeGenerator + ?Sized>( irrt::call_isinf(generator, ctx, x).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1263,7 +1263,7 @@ pub fn call_numpy_sin<'ctx, G: CodeGenerator + ?Sized>( llvm_intrinsics::call_float_sin(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1303,7 +1303,7 @@ pub fn call_numpy_cos<'ctx, G: CodeGenerator + ?Sized>( llvm_intrinsics::call_float_cos(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1343,7 +1343,7 @@ pub fn call_numpy_exp<'ctx, G: CodeGenerator + ?Sized>( llvm_intrinsics::call_float_exp(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1383,7 +1383,7 @@ pub fn call_numpy_exp2<'ctx, G: CodeGenerator + ?Sized>( llvm_intrinsics::call_float_exp2(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1423,7 +1423,7 @@ pub fn call_numpy_log<'ctx, G: CodeGenerator + ?Sized>( llvm_intrinsics::call_float_log(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1463,7 +1463,7 @@ pub fn call_numpy_log10<'ctx, G: CodeGenerator + ?Sized>( llvm_intrinsics::call_float_log10(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1503,7 +1503,7 @@ pub fn call_numpy_log2<'ctx, G: CodeGenerator + ?Sized>( llvm_intrinsics::call_float_log2(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1543,7 +1543,7 @@ pub fn call_numpy_fabs<'ctx, G: CodeGenerator + ?Sized>( llvm_intrinsics::call_float_fabs(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1583,7 +1583,7 @@ pub fn call_numpy_sqrt<'ctx, G: CodeGenerator + ?Sized>( llvm_intrinsics::call_float_sqrt(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1623,7 +1623,7 @@ pub fn call_numpy_rint<'ctx, G: CodeGenerator + ?Sized>( llvm_intrinsics::call_float_roundeven(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1663,7 +1663,7 @@ pub fn call_numpy_tan<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_tan(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1703,7 +1703,7 @@ pub fn call_numpy_arcsin<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_asin(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1743,7 +1743,7 @@ pub fn call_numpy_arccos<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_acos(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1783,7 +1783,7 @@ pub fn call_numpy_arctan<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_atan(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1823,7 +1823,7 @@ pub fn call_numpy_sinh<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_sinh(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1863,7 +1863,7 @@ pub fn call_numpy_cosh<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_cosh(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1903,7 +1903,7 @@ pub fn call_numpy_tanh<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_tanh(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1943,7 +1943,7 @@ pub fn call_numpy_arcsinh<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_asinh(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -1983,7 +1983,7 @@ pub fn call_numpy_arccosh<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_acosh(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -2023,7 +2023,7 @@ pub fn call_numpy_arctanh<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_atanh(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -2063,7 +2063,7 @@ pub fn call_numpy_expm1<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_expm1(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -2103,7 +2103,7 @@ pub fn call_numpy_cbrt<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_cbrt(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -2143,7 +2143,7 @@ pub fn call_scipy_special_erf<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_erf(ctx, z, None).into() } - BasicValueEnum::PointerValue(z) if z_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(z) if z_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, z_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -2183,7 +2183,7 @@ pub fn call_scipy_special_erfc<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_erfc(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -2223,7 +2223,7 @@ pub fn call_scipy_special_gamma<'ctx, G: CodeGenerator + ?Sized>( irrt::call_gamma(ctx, z).into() } - BasicValueEnum::PointerValue(z) if z_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(z) if z_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, z_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -2263,7 +2263,7 @@ pub fn call_scipy_special_gammaln<'ctx, G: CodeGenerator + ?Sized>( irrt::call_gammaln(ctx, x).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -2303,7 +2303,7 @@ pub fn call_scipy_special_j0<'ctx, G: CodeGenerator + ?Sized>( irrt::call_j0(ctx, x).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -2343,7 +2343,7 @@ pub fn call_scipy_special_j1<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_j1(ctx, x, None).into() } - BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) => { + BasicValueEnum::PointerValue(x) if x_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => { let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x_ty); let ndarray = ndarray_elementwise_unaryop_impl( @@ -2384,11 +2384,11 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_atan2(ctx, x1, x2, None).into() } - (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id())) => { + (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())) => { let is_ndarray1 = x1_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PrimitiveDefinition::NDArray.id()); + .is_some_and(|id| id == PrimDef::NDArray.id()); let is_ndarray2 = x2_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PrimitiveDefinition::NDArray.id()); + .is_some_and(|id| id == PrimDef::NDArray.id()); let dtype = if is_ndarray1 && is_ndarray2 { let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); @@ -2451,11 +2451,11 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>( llvm_intrinsics::call_float_copysign(ctx, x1, x2, None).into() } - (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id())) => { + (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())) => { let is_ndarray1 = x1_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PrimitiveDefinition::NDArray.id()); + .is_some_and(|id| id == PrimDef::NDArray.id()); let is_ndarray2 = x2_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PrimitiveDefinition::NDArray.id()); + .is_some_and(|id| id == PrimDef::NDArray.id()); let dtype = if is_ndarray1 && is_ndarray2 { let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); @@ -2518,11 +2518,11 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>( llvm_intrinsics::call_float_maxnum(ctx, x1, x2, None).into() } - (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id())) => { + (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())) => { let is_ndarray1 = x1_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PrimitiveDefinition::NDArray.id()); + .is_some_and(|id| id == PrimDef::NDArray.id()); let is_ndarray2 = x2_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PrimitiveDefinition::NDArray.id()); + .is_some_and(|id| id == PrimDef::NDArray.id()); let dtype = if is_ndarray1 && is_ndarray2 { let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); @@ -2585,11 +2585,11 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>( llvm_intrinsics::call_float_minnum(ctx, x1, x2, None).into() } - (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id())) => { + (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())) => { let is_ndarray1 = x1_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PrimitiveDefinition::NDArray.id()); + .is_some_and(|id| id == PrimDef::NDArray.id()); let is_ndarray2 = x2_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PrimitiveDefinition::NDArray.id()); + .is_some_and(|id| id == PrimDef::NDArray.id()); let dtype = if is_ndarray1 && is_ndarray2 { let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); @@ -2652,11 +2652,11 @@ pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_ldexp(ctx, x1, x2, None).into() } - (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id())) => { + (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())) => { let is_ndarray1 = x1_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PrimitiveDefinition::NDArray.id()); + .is_some_and(|id| id == PrimDef::NDArray.id()); let is_ndarray2 = x2_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PrimitiveDefinition::NDArray.id()); + .is_some_and(|id| id == PrimDef::NDArray.id()); let dtype = if is_ndarray1 { unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty).0 @@ -2708,11 +2708,11 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_hypot(ctx, x1, x2, None).into() } - (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id())) => { + (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())) => { let is_ndarray1 = x1_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PrimitiveDefinition::NDArray.id()); + .is_some_and(|id| id == PrimDef::NDArray.id()); let is_ndarray2 = x2_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PrimitiveDefinition::NDArray.id()); + .is_some_and(|id| id == PrimDef::NDArray.id()); let dtype = if is_ndarray1 && is_ndarray2 { let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); @@ -2775,11 +2775,11 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>( extern_fns::call_nextafter(ctx, x1, x2, None).into() } - (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id())) => { + (x1, x2) if [&x1_ty, &x2_ty].into_iter().any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())) => { let is_ndarray1 = x1_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PrimitiveDefinition::NDArray.id()); + .is_some_and(|id| id == PrimDef::NDArray.id()); let is_ndarray2 = x2_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PrimitiveDefinition::NDArray.id()); + .is_some_and(|id| id == PrimDef::NDArray.id()); let dtype = if is_ndarray1 && is_ndarray2 { let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty); diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index b58eb62..60e9e9f 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -26,7 +26,7 @@ use crate::{ symbol_resolver::{SymbolValue, ValueEnum}, toplevel::{ DefinitionId, - helper::PrimitiveDefinition, + helper::PrimDef, numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, TopLevelDef, }, @@ -1133,13 +1133,13 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( Some("f_pow_i") ); Ok(Some(res.into())) - } else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) || ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) { + } else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) || ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { let llvm_usize = generator.get_size_type(ctx.ctx); let is_ndarray1 = ty1.obj_id(&ctx.unifier) - .is_some_and(|id| id == PrimitiveDefinition::NDArray.id()); + .is_some_and(|id| id == PrimDef::NDArray.id()); let is_ndarray2 = ty2.obj_id(&ctx.unifier) - .is_some_and(|id| id == PrimitiveDefinition::NDArray.id()); + .is_some_and(|id| id == PrimDef::NDArray.id()); if is_ndarray1 && is_ndarray2 { let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1); @@ -1373,7 +1373,7 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( .unwrap(), _ => val.into(), } - } else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) { + } else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { let llvm_usize = generator.get_size_type(ctx.ctx); let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); @@ -1385,7 +1385,7 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( // ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before // passing it to the elementwise codegen function - let op = if ndarray_dtype.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::Bool.id()) { + let op = if ndarray_dtype.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::Bool.id()) { if *op == ast::Unaryop::Invert { &ast::Unaryop::Not } else { @@ -1451,7 +1451,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( let left_ty = ctx.unifier.get_representative(left.0.unwrap()); let right_ty = ctx.unifier.get_representative(comparators[0].0.unwrap()); - if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) || right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) { + if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) || right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { let llvm_usize = generator.get_size_type(ctx.ctx); let (Some(left_ty), lhs) = left else { unreachable!() }; @@ -1459,9 +1459,9 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( let op = ops[0].clone(); let is_ndarray1 = left_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PrimitiveDefinition::NDArray.id()); + .is_some_and(|id| id == PrimDef::NDArray.id()); let is_ndarray2 = right_ty.obj_id(&ctx.unifier) - .is_some_and(|id| id == PrimitiveDefinition::NDArray.id()); + .is_some_and(|id| id == PrimDef::NDArray.id()); return if is_ndarray1 && is_ndarray2 { let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, left_ty); @@ -2452,7 +2452,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( v.data().get(ctx, generator, &index, None).into() } } - TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimitiveDefinition::NDArray.id() => { + TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::NDArray.id() => { let (ty, ndims) = params.iter() .map(|(_, ty)| ty) .collect_tuple() diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 543bd54..a17ff30 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -1,7 +1,7 @@ use crate::{ symbol_resolver::{StaticValue, SymbolResolver}, toplevel::{ - helper::PrimitiveDefinition, + helper::PrimDef, numpy::unpack_ndarray_var_tys, TopLevelContext, TopLevelDef, @@ -435,9 +435,9 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( let result = match &*ty_enum { TObj { obj_id, fields, .. } => { // check to avoid treating non-class primitives as classes - if PrimitiveDefinition::contains_id(*obj_id) { + if PrimDef::contains_id(*obj_id) { return match &*unifier.get_ty_immutable(ty) { - TObj { obj_id, params, .. } if *obj_id == PrimitiveDefinition::Option.id() => { + TObj { obj_id, params, .. } if *obj_id == PrimDef::Option.id() => { get_llvm_type( ctx, module, @@ -451,7 +451,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( .into() } - TObj { obj_id, .. } if *obj_id == PrimitiveDefinition::NDArray.id() => { + TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { let llvm_usize = generator.get_size_type(ctx); let (dtype, _) = unpack_ndarray_var_tys(unifier, ty); let element_type = get_llvm_type( diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 99c0561..a7c4f27 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -12,7 +12,7 @@ use crate::{ }, toplevel::{ DefinitionId, - helper::PrimitiveDefinition, + helper::PrimDef, numpy::unpack_ndarray_var_tys, TopLevelDef, }, @@ -192,7 +192,7 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>( v.data().ptr_offset(ctx, generator, &index, name) } - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimitiveDefinition::NDArray.id() => { + TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { todo!() } @@ -254,7 +254,7 @@ pub fn gen_assign<'ctx, G: CodeGenerator>( let value = ListValue::from_ptr_val(value, llvm_usize, None); let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) { TypeEnum::TList { ty } => *ty, - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimitiveDefinition::NDArray.id() => { + TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { unpack_ndarray_var_tys(&mut ctx.unifier, target.custom.unwrap()).0 } _ => unreachable!(), diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 2d68716..797f820 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1,13 +1,15 @@ use std::iter::once; +use helper::debug_assert_prim_is_allowed; use indexmap::IndexMap; use inkwell::{ attributes::{Attribute, AttributeLoc}, - IntPredicate, types::{BasicMetadataTypeEnum, BasicType}, - values::{BasicMetadataValueEnum, BasicValue, CallSiteValue} + values::{BasicMetadataValueEnum, BasicValue, CallSiteValue}, + IntPredicate, }; use itertools::Either; +use strum::IntoEnumIterator; use crate::{ codegen::{ @@ -19,10 +21,7 @@ use crate::{ stmt::exn_constructor, }, symbol_resolver::SymbolValue, - toplevel::{ - helper::PrimitiveDefinition, - numpy::make_ndarray_ty, - }, + toplevel::{helper::PrimDef, numpy::make_ndarray_ty}, typecheck::typedef::VarMap, }; @@ -35,8 +34,8 @@ pub fn get_exn_constructor( class_id: usize, cons_id: usize, unifier: &mut Unifier, - primitives: &PrimitiveStore -)-> (TopLevelDef, TopLevelDef, Type, Type) { + primitives: &PrimitiveStore, +) -> (TopLevelDef, TopLevelDef, Type, Type) { let int32 = primitives.int32; let int64 = primitives.int64; let string = primitives.str; @@ -90,7 +89,7 @@ pub fn get_exn_constructor( methods: vec![("__init__".into(), signature, DefinitionId(cons_id))], ancestors: vec![ TypeAnnotation::CustomClass { id: DefinitionId(class_id), params: Vec::default() }, - TypeAnnotation::CustomClass { id: PrimitiveDefinition::Exception.id(), params: Vec::default() }, + TypeAnnotation::CustomClass { id: PrimDef::Exception.id(), params: Vec::default() }, ], constructor: Some(signature), resolver: None, @@ -113,16 +112,15 @@ fn create_fn_by_codegen( ret_ty: Type, param_ty: &[(Type, &'static str)], codegen_callback: Box, -) -> Arc> { - Arc::new(RwLock::new(TopLevelDef::Function { +) -> TopLevelDef { + TopLevelDef::Function { name: name.into(), simple_name: name.into(), signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: param_ty.iter().map(|p| FuncArg { - name: p.1.into(), - ty: p.0, - default_value: None, - }).collect(), + args: param_ty + .iter() + .map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None }) + .collect(), ret: ret_ty, vars: var_map.clone(), })), @@ -132,7 +130,7 @@ fn create_fn_by_codegen( resolver: None, codegen_callback: Some(Arc::new(GenCall::new(codegen_callback))), loc: None, - })) + } } /// Creates a NumPy [`TopLevelDef`] function using an LLVM intrinsic. @@ -149,10 +147,8 @@ fn create_fn_by_intrinsic( ret_ty: Type, params: &[(Type, &'static str)], intrinsic_fn: &'static str, -) -> Arc> { - let param_tys = params.iter() - .map(|p| p.0) - .collect_vec(); +) -> TopLevelDef { + let param_tys = params.iter().map(|p| p.0).collect_vec(); create_fn_by_codegen( unifier, @@ -163,21 +159,22 @@ fn create_fn_by_intrinsic( Box::new(move |ctx, _, fun, args, generator| { let args_ty = fun.0.args.iter().map(|a| a.ty).collect_vec(); - assert!(param_tys.iter().zip(&args_ty) + assert!(param_tys + .iter() + .zip(&args_ty) .all(|(expected, actual)| ctx.unifier.unioned(*expected, *actual))); - let args_val = args_ty.iter().zip_eq(args.iter()) - .map(|(ty, arg)| { - arg.1.clone() - .to_basic_value_enum(ctx, generator, *ty) - .unwrap() - }) + let args_val = args_ty + .iter() + .zip_eq(args.iter()) + .map(|(ty, arg)| arg.1.clone().to_basic_value_enum(ctx, generator, *ty).unwrap()) .map_into::() .collect_vec(); let intrinsic_fn = ctx.module.get_function(intrinsic_fn).unwrap_or_else(|| { let ret_llvm_ty = ctx.get_llvm_abi_type(generator, ret_ty); - let param_llvm_ty = param_tys.iter() + let param_llvm_ty = param_tys + .iter() .map(|p| ctx.get_llvm_abi_type(generator, *p)) .map_into::() .collect_vec(); @@ -186,7 +183,8 @@ fn create_fn_by_intrinsic( ctx.module.add_function(intrinsic_fn, fn_type, None) }); - let val = ctx.builder + let val = ctx + .builder .build_call(intrinsic_fn, args_val.as_slice(), name) .map(CallSiteValue::try_as_basic_value) .map(Either::unwrap_left) @@ -214,10 +212,8 @@ fn create_fn_by_extern( params: &[(Type, &'static str)], extern_fn: &'static str, attrs: &'static [&str], -) -> Arc> { - let param_tys = params.iter() - .map(|p| p.0) - .collect_vec(); +) -> TopLevelDef { + let param_tys = params.iter().map(|p| p.0).collect_vec(); create_fn_by_codegen( unifier, @@ -228,694 +224,837 @@ fn create_fn_by_extern( Box::new(move |ctx, _, fun, args, generator| { let args_ty = fun.0.args.iter().map(|a| a.ty).collect_vec(); - assert!(param_tys.iter().zip(&args_ty) + assert!(param_tys + .iter() + .zip(&args_ty) .all(|(expected, actual)| ctx.unifier.unioned(*expected, *actual))); - let args_val = args_ty.iter().zip_eq(args.iter()) - .map(|(ty, arg)| { - arg.1.clone() - .to_basic_value_enum(ctx, generator, *ty) - .unwrap() - }) + let args_val = args_ty + .iter() + .zip_eq(args.iter()) + .map(|(ty, arg)| arg.1.clone().to_basic_value_enum(ctx, generator, *ty).unwrap()) .map_into::() .collect_vec(); - let intrinsic_fn = ctx.module.get_function(extern_fn).unwrap_or_else(|| { - let ret_llvm_ty = ctx.get_llvm_abi_type(generator, ret_ty); - let param_llvm_ty = param_tys.iter() - .map(|p| ctx.get_llvm_abi_type(generator, *p)) - .map_into::() - .collect_vec(); - let fn_type = ret_llvm_ty.fn_type(param_llvm_ty.as_slice(), false); - let func = ctx.module.add_function(extern_fn, fn_type, None); + let intrinsic_fn = ctx.module.get_function(extern_fn).unwrap_or_else(|| { + let ret_llvm_ty = ctx.get_llvm_abi_type(generator, ret_ty); + let param_llvm_ty = param_tys + .iter() + .map(|p| ctx.get_llvm_abi_type(generator, *p)) + .map_into::() + .collect_vec(); + let fn_type = ret_llvm_ty.fn_type(param_llvm_ty.as_slice(), false); + let func = ctx.module.add_function(extern_fn, fn_type, None); + func.add_attribute( + AttributeLoc::Function, + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0), + ); + + for attr in attrs { func.add_attribute( AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0) + ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0), ); + } - for attr in attrs { - func.add_attribute( - AttributeLoc::Function, - ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0) - ); - } + func + }); - func - }); - - let val = ctx.builder - .build_call(intrinsic_fn, &args_val, name) - .map(CallSiteValue::try_as_basic_value) - .map(Either::unwrap_left) - .unwrap(); - Ok(val.into()) + let val = ctx + .builder + .build_call(intrinsic_fn, &args_val, name) + .map(CallSiteValue::try_as_basic_value) + .map(Either::unwrap_left) + .unwrap(); + Ok(val.into()) }), ) } pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> BuiltinInfo { - let PrimitiveStore { - int32, - int64, - uint32, - uint64, - float, - bool: boolean, - range, - str: string, - ndarray, - .. - } = *primitives; + let top_level_def_list = BuiltinBuilder::new(unifier, primitives) + .build_all_builtins() + .into_iter() + .map(|tld| Arc::new(RwLock::new(tld))); - let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), None); - let ndarray_float_2d = { - let value = match primitives.size_t { - 64 => SymbolValue::U64(2u64), - 32 => SymbolValue::U32(2u32), - _ => unreachable!(), - }; - let ndims = unifier.add_ty(TypeEnum::TLiteral { - values: vec![value], - loc: None, - }); + let ast_list: Vec>> = (0..top_level_def_list.len()).map(|_| None).collect(); - make_ndarray_ty(unifier, primitives, Some(float), Some(ndims)) - }; - let list_int32 = unifier.add_ty(TypeEnum::TList { ty: int32 }); - let num_ty = unifier.get_fresh_var_with_range( - &[int32, int64, float, boolean, uint32, uint64], - Some("N".into()), - None, - ); - let num_var_map: VarMap = vec![ - (num_ty.1, num_ty.0), - ].into_iter().collect(); + izip!(top_level_def_list, ast_list).collect_vec() +} - let new_type_or_ndarray_ty = |unifier: &mut Unifier, primitives: &PrimitiveStore, scalar_ty: Type| { - let ndarray = make_ndarray_ty(unifier, primitives, Some(scalar_ty), None); +/// A helper enum used by [`BuiltinBuilder`] +#[derive(Clone, Copy)] +enum SizeVariant { + Bits32, + Bits64, +} - unifier.get_fresh_var_with_range( - &[scalar_ty, ndarray], - Some("T".into()), - None, - ) - }; +impl SizeVariant { + fn of_int(self, primitives: &PrimitiveStore) -> Type { + match self { + SizeVariant::Bits32 => primitives.int32, + SizeVariant::Bits64 => primitives.int64, + } + } +} - let ndarray_num_ty = make_ndarray_ty(unifier, primitives, Some(num_ty.0), None); - let float_or_ndarray_ty = unifier.get_fresh_var_with_range( - &[float, ndarray_float], - Some("T".into()), - None, - ); - let float_or_ndarray_var_map: VarMap = vec![ - (float_or_ndarray_ty.1, float_or_ndarray_ty.0), - ].into_iter().collect(); +struct BuiltinBuilder<'a> { + unifier: &'a mut Unifier, + primitives: &'a PrimitiveStore, - let num_or_ndarray_ty = unifier.get_fresh_var_with_range( - &[num_ty.0, ndarray_num_ty], - Some("T".into()), - None, - ); - let num_or_ndarray_var_map: VarMap = vec![ - (num_ty.1, num_ty.0), - (num_or_ndarray_ty.1, num_or_ndarray_ty.0), - ].into_iter().collect(); + is_some_ty: (Type, bool), + unwrap_ty: (Type, bool), + option_tvar: (Type, u32), - let exception_fields = vec![ - ("__name__".into(), int32, true), - ("__file__".into(), string, true), - ("__line__".into(), int32, true), - ("__col__".into(), int32, true), - ("__func__".into(), string, true), - ("__message__".into(), string, true), - ("__param0__".into(), int64, true), - ("__param1__".into(), int64, true), - ("__param2__".into(), int64, true), - ]; + ndarray_dtype_tvar: (Type, u32), + ndarray_ndims_tvar: (Type, u32), + ndarray_copy_ty: (Type, bool), + ndarray_fill_ty: (Type, bool), - // for Option, is_some and is_none share the same type: () -> bool, - // and they are methods under the same class `Option` - let (is_some_ty, unwrap_ty, (option_ty_var, option_ty_var_id)) = - if let TypeEnum::TObj { fields, params, .. } = - unifier.get_ty(primitives.option).as_ref() - { - ( - *fields.get(&"is_some".into()).unwrap(), - *fields.get(&"unwrap".into()).unwrap(), - (*params.iter().next().unwrap().1, *params.iter().next().unwrap().0), - ) - } else { + list_int32: Type, + + num_ty: (Type, u32), + num_var_map: VarMap, + + ndarray_float: Type, + ndarray_float_2d: Type, + ndarray_num_ty: Type, + + float_or_ndarray_ty: (Type, u32), + float_or_ndarray_var_map: VarMap, + + num_or_ndarray_ty: (Type, u32), + num_or_ndarray_var_map: VarMap, +} + +impl<'a> BuiltinBuilder<'a> { + fn new(unifier: &'a mut Unifier, primitives: &'a PrimitiveStore) -> BuiltinBuilder<'a> { + let PrimitiveStore { + int32, + int64, + uint32, + uint64, + float, + bool: boolean, + ndarray, + option, + .. + } = *primitives; + + // Option-related + let (is_some_ty, unwrap_ty, option_tvar) = + if let TypeEnum::TObj { fields, params, .. } = unifier.get_ty(option).as_ref() { + ( + *fields.get(&PrimDef::OptionIsSome.simple_name().into()).unwrap(), + *fields.get(&PrimDef::OptionUnwrap.simple_name().into()).unwrap(), + (*params.iter().next().unwrap().1, *params.iter().next().unwrap().0), + ) + } else { + unreachable!() + }; + + let TypeEnum::TObj { fields: ndarray_fields, params: ndarray_params, .. } = + &*unifier.get_ty(ndarray) + else { unreachable!() }; + let ndarray_dtype_tvar = + ndarray_params.iter().next().map(|(var_id, ty)| (*ty, *var_id)).unwrap(); + let ndarray_ndims_tvar = + ndarray_params.iter().nth(1).map(|(var_id, ty)| (*ty, *var_id)).unwrap(); + let ndarray_copy_ty = + *ndarray_fields.get(&PrimDef::NDArrayCopy.simple_name().into()).unwrap(); + let ndarray_fill_ty = + *ndarray_fields.get(&PrimDef::NDArrayFill.simple_name().into()).unwrap(); - let TypeEnum::TObj { - fields: ndarray_fields, - params: ndarray_params, - .. - } = &*unifier.get_ty(primitives.ndarray) else { - unreachable!() - }; + let num_ty = unifier.get_fresh_var_with_range( + &[int32, int64, float, boolean, uint32, uint64], + Some("N".into()), + None, + ); + let num_var_map: VarMap = vec![(num_ty.1, num_ty.0)].into_iter().collect(); - let (ndarray_dtype_ty, ndarray_dtype_var_id) = ndarray_params - .iter() - .next() - .map(|(var_id, ty)| (*ty, *var_id)) - .unwrap(); - let (ndarray_ndims_ty, ndarray_ndims_var_id) = ndarray_params - .iter() - .nth(1) - .map(|(var_id, ty)| (*ty, *var_id)) - .unwrap(); - let ndarray_copy_ty = *ndarray_fields.get(&"copy".into()).unwrap(); - let ndarray_fill_ty = *ndarray_fields.get(&"fill".into()).unwrap(); + let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), None); + let ndarray_float_2d = { + let value = match primitives.size_t { + 64 => SymbolValue::U64(2u64), + 32 => SymbolValue::U32(2u32), + _ => unreachable!(), + }; + let ndims = unifier.add_ty(TypeEnum::TLiteral { values: vec![value], loc: None }); - let top_level_def_list = vec![ - Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( - PrimitiveDefinition::Int32.id(), - None, - "int32".into(), - None, - None, - ))), - Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( - PrimitiveDefinition::Int64.id(), - None, - "int64".into(), - None, - None, - ))), - Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( - PrimitiveDefinition::Float.id(), - None, - "float".into(), - None, - None, - ))), - Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( - PrimitiveDefinition::Bool.id(), - None, - "bool".into(), - None, - None, - ))), - Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( - PrimitiveDefinition::None.id(), - None, - "none".into(), - None, - None, - ))), - Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( - PrimitiveDefinition::Range.id(), - None, - "range".into(), - None, - None, - ))), - Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( - PrimitiveDefinition::Str.id(), - None, - "str".into(), - None, - None, - ))), - Arc::new(RwLock::new(TopLevelDef::Class { - name: "Exception".into(), - object_id: PrimitiveDefinition::Exception.id(), - type_vars: Vec::default(), - fields: exception_fields, - methods: Vec::default(), - ancestors: vec![], - constructor: None, - resolver: None, - loc: None, - })), - Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( - PrimitiveDefinition::UInt32.id(), - None, - "uint32".into(), - None, - None, - ))), - Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( - PrimitiveDefinition::UInt64.id(), - None, - "uint64".into(), - None, - None, - ))), - Arc::new(RwLock::new({ - TopLevelDef::Class { - name: "Option".into(), - object_id: PrimitiveDefinition::Option.id(), - type_vars: vec![option_ty_var], + make_ndarray_ty(unifier, primitives, Some(float), Some(ndims)) + }; + + let ndarray_num_ty = make_ndarray_ty(unifier, primitives, Some(num_ty.0), None); + let float_or_ndarray_ty = + unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None); + let float_or_ndarray_var_map: VarMap = + vec![(float_or_ndarray_ty.1, float_or_ndarray_ty.0)].into_iter().collect(); + + let num_or_ndarray_ty = + unifier.get_fresh_var_with_range(&[num_ty.0, ndarray_num_ty], Some("T".into()), None); + let num_or_ndarray_var_map: VarMap = + vec![(num_ty.1, num_ty.0), (num_or_ndarray_ty.1, num_or_ndarray_ty.0)] + .into_iter() + .collect(); + + let list_int32 = unifier.add_ty(TypeEnum::TList { ty: int32 }); + + BuiltinBuilder { + unifier, + primitives, + + is_some_ty, + unwrap_ty, + option_tvar, + + ndarray_dtype_tvar, + ndarray_ndims_tvar, + ndarray_copy_ty, + ndarray_fill_ty, + + list_int32, + + num_ty, + num_var_map, + + ndarray_float, + ndarray_float_2d, + ndarray_num_ty, + + float_or_ndarray_ty, + float_or_ndarray_var_map, + + num_or_ndarray_ty, + num_or_ndarray_var_map, + } + } + + /// Construct every function from every [`PrimDef`], in the order of [`PrimDef`]'s definition. + fn build_all_builtins(&mut self) -> Vec { + PrimDef::iter().map(|prim| self.build_builtin_of_prim(prim)).collect_vec() + } + + fn build_builtin_of_prim(&mut self, prim: PrimDef) -> TopLevelDef { + match prim { + PrimDef::Int32 + | PrimDef::Int64 + | PrimDef::UInt32 + | PrimDef::UInt64 + | PrimDef::Float + | PrimDef::Bool + | PrimDef::Str + | PrimDef::Range + | PrimDef::None => self.build_simple_primitive_class(prim), + + PrimDef::Exception => self.build_exception_class_related(prim), + + PrimDef::Option + | PrimDef::OptionIsSome + | PrimDef::OptionIsNone + | PrimDef::OptionUnwrap + | PrimDef::FunSome => self.build_option_class_related(prim), + + PrimDef::NDArray | PrimDef::NDArrayCopy | PrimDef::NDArrayFill => { + self.build_ndarray_class_related(prim) + } + + PrimDef::FunInt32 + | PrimDef::FunInt64 + | PrimDef::FunUInt32 + | PrimDef::FunUInt64 + | PrimDef::FunFloat + | PrimDef::FunBool => self.build_cast_function(prim), + + PrimDef::FunNpNDArray + | PrimDef::FunNpEmpty + | PrimDef::FunNpZeros + | PrimDef::FunNpOnes => self.build_ndarray_from_shape_factory_function(prim), + + PrimDef::FunNpFull | PrimDef::FunNpEye | PrimDef::FunNpIdentity => { + self.build_ndarray_other_factory_function(prim) + } + + PrimDef::FunRange => self.build_range_function(), + PrimDef::FunStr => self.build_str_function(), + + PrimDef::FunFloor | PrimDef::FunFloor64 | PrimDef::FunCeil | PrimDef::FunCeil64 => { + self.build_ceil_floor_function(prim) + } + + PrimDef::FunAbs => self.build_abs_function(), + + PrimDef::FunRound | PrimDef::FunRound64 => self.build_round_function(prim), + + PrimDef::FunNpFloor | PrimDef::FunNpCeil => self.build_np_ceil_floor_function(prim), + + PrimDef::FunNpRound => self.build_np_round_function(), + + PrimDef::FunLen => self.build_len_function(), + + PrimDef::FunMin | PrimDef::FunMax => self.build_min_max_function(prim), + + PrimDef::FunNpMin | PrimDef::FunNpMax => self.build_np_min_max_function(prim), + + PrimDef::FunNpMinimum | PrimDef::FunNpMaximum => { + self.build_np_minimum_maximum_function(prim) + } + + PrimDef::FunNpIsNan | PrimDef::FunNpIsInf => self.build_np_float_to_bool_function(prim), + + PrimDef::FunNpSin + | PrimDef::FunNpCos + | PrimDef::FunNpTan + | PrimDef::FunNpArcsin + | PrimDef::FunNpArccos + | PrimDef::FunNpArctan + | PrimDef::FunNpSinh + | PrimDef::FunNpCosh + | PrimDef::FunNpTanh + | PrimDef::FunNpArcsinh + | PrimDef::FunNpArccosh + | PrimDef::FunNpArctanh + | PrimDef::FunNpExp + | PrimDef::FunNpExp2 + | PrimDef::FunNpExpm1 + | PrimDef::FunNpLog + | PrimDef::FunNpLog2 + | PrimDef::FunNpLog10 + | PrimDef::FunNpSqrt + | PrimDef::FunNpCbrt + | PrimDef::FunNpFabs + | PrimDef::FunNpRint + | PrimDef::FunSpSpecErf + | PrimDef::FunSpSpecErfc + | PrimDef::FunSpSpecGamma + | PrimDef::FunSpSpecGammaln + | PrimDef::FunSpSpecJ0 + | PrimDef::FunSpSpecJ1 => self.build_np_sp_float_or_ndarray_1ary_function(prim), + + PrimDef::FunNpArctan2 + | PrimDef::FunNpCopysign + | PrimDef::FunNpFmax + | PrimDef::FunNpFmin + | PrimDef::FunNpLdExp + | PrimDef::FunNpHypot + | PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim), + } + } + + fn build_simple_primitive_class(&self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed( + prim, + [ + PrimDef::Int32, + PrimDef::Int64, + PrimDef::UInt32, + PrimDef::UInt64, + PrimDef::Float, + PrimDef::Bool, + PrimDef::Str, + PrimDef::Range, + PrimDef::None, + ], + ); + + TopLevelComposer::make_top_level_class_def(prim.id(), None, prim.name().into(), None, None) + } + + fn build_exception_class_related(&self, prim: PrimDef) -> TopLevelDef { + // NOTE: currently only contains the class `Exception` + debug_assert_prim_is_allowed(prim, [PrimDef::Exception]); + + let PrimitiveStore { int32, int64, str, .. } = *self.primitives; + + match prim { + PrimDef::Exception => { + let exception_fields: Vec<(StrRef, Type, bool)> = vec![ + ("__name__".into(), int32, true), + ("__file__".into(), str, true), + ("__line__".into(), int32, true), + ("__col__".into(), int32, true), + ("__func__".into(), str, true), + ("__message__".into(), str, true), + ("__param0__".into(), int64, true), + ("__param1__".into(), int64, true), + ("__param2__".into(), int64, true), + ]; + + TopLevelDef::Class { + name: prim.name().into(), + object_id: prim.id(), + type_vars: Vec::default(), + fields: exception_fields, + methods: Vec::default(), + ancestors: vec![], + constructor: None, + resolver: None, + loc: None, + } + } + _ => unreachable!(), + } + } + + fn build_option_class_related(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed( + prim, + [ + PrimDef::Option, + PrimDef::OptionIsSome, + PrimDef::OptionIsNone, + PrimDef::OptionUnwrap, + PrimDef::FunSome, + ], + ); + + match prim { + PrimDef::Option => TopLevelDef::Class { + name: prim.name().into(), + object_id: prim.id(), + type_vars: vec![self.option_tvar.0], fields: vec![], methods: vec![ - ("is_some".into(), is_some_ty.0, PrimitiveDefinition::OptionIsSome.id()), - ("is_none".into(), is_some_ty.0, PrimitiveDefinition::OptionIsNone.id()), - ("unwrap".into(), unwrap_ty.0, PrimitiveDefinition::OptionUnwrap.id()), + self.create_method(PrimDef::OptionIsSome, self.is_some_ty.0), + self.create_method(PrimDef::OptionIsNone, self.is_some_ty.0), + self.create_method(PrimDef::OptionUnwrap, self.unwrap_ty.0), ], ancestors: vec![TypeAnnotation::CustomClass { - id: PrimitiveDefinition::Option.id(), + id: prim.id(), params: Vec::default(), }], constructor: None, resolver: None, loc: None, - } - })), - Arc::new(RwLock::new(TopLevelDef::Function { - name: "Option.is_some".into(), - simple_name: "is_some".into(), - signature: is_some_ty.0, - var_id: vec![option_ty_var_id], - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, obj, _, _, generator| { - let expect_ty = obj.clone().unwrap().0; - let obj_val = obj.unwrap().1.clone().to_basic_value_enum( - ctx, - generator, - expect_ty, - )?; - let BasicValueEnum::PointerValue(ptr) = obj_val else { - unreachable!("option must be ptr") - }; + }, - Ok(Some(ctx.builder - .build_is_not_null(ptr, "is_some") - .map(Into::into) - .unwrap() - )) - }, - )))), - loc: None, - })), - Arc::new(RwLock::new(TopLevelDef::Function { - name: "Option.is_none".into(), - simple_name: "is_none".into(), - signature: is_some_ty.0, - var_id: vec![option_ty_var_id], - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, obj, _, _, generator| { - let expect_ty = obj.clone().unwrap().0; - let obj_val = obj.unwrap().1.clone().to_basic_value_enum( - ctx, - generator, - expect_ty, - )?; - let BasicValueEnum::PointerValue(ptr) = obj_val else { - unreachable!("option must be ptr") - }; + PrimDef::OptionUnwrap => TopLevelDef::Function { + name: prim.name().into(), + simple_name: prim.simple_name().into(), + signature: self.unwrap_ty.0, + var_id: vec![self.option_tvar.1], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::create_dummy(String::from( + "handled in gen_expr", + )))), + loc: None, + }, - Ok(Some(ctx.builder - .build_is_null(ptr, "is_none") - .map(Into::into) - .unwrap() - )) - }, - )))), - loc: None, - })), - Arc::new(RwLock::new(TopLevelDef::Function { - name: "Option.unwrap".into(), - simple_name: "unwrap".into(), - signature: unwrap_ty.0, - var_id: vec![option_ty_var_id], - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::create_dummy( - String::from("handled in gen_expr"), - ))), - loc: None, - })), - Arc::new(RwLock::new(TopLevelDef::Class { - name: "ndarray".into(), - object_id: PrimitiveDefinition::NDArray.id(), - type_vars: vec![ndarray_dtype_ty, ndarray_ndims_ty], - fields: Vec::default(), - methods: vec![ - ("copy".into(), ndarray_copy_ty.0, PrimitiveDefinition::NDArrayCopy.id()), - ("fill".into(), ndarray_fill_ty.0, PrimitiveDefinition::NDArrayFill.id()), - ], - ancestors: Vec::default(), - constructor: None, - resolver: None, - loc: None, - })), - Arc::new(RwLock::new(TopLevelDef::Function { - name: "ndarray.copy".into(), - simple_name: "copy".into(), - signature: ndarray_copy_ty.0, - var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, obj, fun, args, generator| { - gen_ndarray_copy(ctx, &obj, fun, &args, generator) - .map(|val| Some(val.as_basic_value_enum())) - }, - )))), - loc: None, - })), - Arc::new(RwLock::new(TopLevelDef::Function { - name: "ndarray.fill".into(), - simple_name: "fill".into(), - signature: ndarray_fill_ty.0, - var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, obj, fun, args, generator| { - gen_ndarray_fill(ctx, &obj, fun, &args, generator)?; - Ok(None) - }, - )))), - loc: None, - })), - Arc::new(RwLock::new(TopLevelDef::Function { - name: "int32".into(), - simple_name: "int32".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { name: "n".into(), ty: num_or_ndarray_ty.0, default_value: None }], - ret: num_or_ndarray_ty.0, - vars: num_or_ndarray_var_map.clone(), - })), - var_id: Vec::default(), - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, fun, args, generator| { - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; + PrimDef::OptionIsNone | PrimDef::OptionIsSome => TopLevelDef::Function { + name: prim.name().to_string(), + simple_name: prim.simple_name().into(), + signature: self.is_some_ty.0, + var_id: vec![self.option_tvar.1], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + move |ctx, obj, _, _, generator| { + let expect_ty = obj.clone().unwrap().0; + let obj_val = obj + .unwrap() + .1 + .clone() + .to_basic_value_enum(ctx, generator, expect_ty)?; + let BasicValueEnum::PointerValue(ptr) = obj_val else { + unreachable!("option must be ptr") + }; - Ok(Some(builtin_fns::call_int32(generator, ctx, (arg_ty, arg))?)) - }, - )))), - loc: None, - })), - Arc::new(RwLock::new(TopLevelDef::Function { - name: "int64".into(), - simple_name: "int64".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { name: "n".into(), ty: num_or_ndarray_ty.0, default_value: None }], - ret: num_or_ndarray_ty.0, - vars: num_or_ndarray_var_map.clone(), - })), - var_id: Vec::default(), - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, fun, args, generator| { - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - - Ok(Some(builtin_fns::call_int64(generator, ctx, (arg_ty, arg))?)) - }, - )))), - loc: None, - })), - Arc::new(RwLock::new(TopLevelDef::Function { - name: "uint32".into(), - simple_name: "uint32".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { name: "n".into(), ty: num_or_ndarray_ty.0, default_value: None }], - ret: num_or_ndarray_ty.0, - vars: num_or_ndarray_var_map.clone(), - })), - var_id: Vec::default(), - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, fun, args, generator| { - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - - Ok(Some(builtin_fns::call_uint32(generator, ctx, (arg_ty, arg))?)) - }, - )))), - loc: None, - })), - Arc::new(RwLock::new(TopLevelDef::Function { - name: "uint64".into(), - simple_name: "uint64".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { name: "n".into(), ty: num_or_ndarray_ty.0, default_value: None }], - ret: num_or_ndarray_ty.0, - vars: num_or_ndarray_var_map.clone(), - })), - var_id: Vec::default(), - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, fun, args, generator| { - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - - Ok(Some(builtin_fns::call_uint64(generator, ctx, (arg_ty, arg))?)) - }, - )))), - loc: None, - })), - Arc::new(RwLock::new(TopLevelDef::Function { - name: "float".into(), - simple_name: "float".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { name: "n".into(), ty: num_or_ndarray_ty.0, default_value: None }], - ret: num_or_ndarray_ty.0, - vars: num_or_ndarray_var_map.clone(), - })), - var_id: Vec::default(), - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, fun, args, generator| { - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - - Ok(Some(builtin_fns::call_float(generator, ctx, (arg_ty, arg))?)) - }, - )))), - loc: None, - })), - create_fn_by_codegen( - unifier, - &VarMap::new(), - "np_ndarray", - ndarray_float, - // We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a - // type variable - &[(list_int32, "shape")], - Box::new(|ctx, obj, fun, args, generator| { - gen_ndarray_empty(ctx, &obj, fun, &args, generator) - .map(|val| Some(val.as_basic_value_enum())) - }), - ), - create_fn_by_codegen( - unifier, - &VarMap::new(), - "np_empty", - ndarray_float, - // We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a - // type variable - &[(list_int32, "shape")], - Box::new(|ctx, obj, fun, args, generator| { - gen_ndarray_empty(ctx, &obj, fun, &args, generator) - .map(|val| Some(val.as_basic_value_enum())) - }), - ), - create_fn_by_codegen( - unifier, - &VarMap::new(), - "np_zeros", - ndarray_float, - // We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a - // type variable - &[(list_int32, "shape")], - Box::new(|ctx, obj, fun, args, generator| { - gen_ndarray_zeros(ctx, &obj, fun, &args, generator) - .map(|val| Some(val.as_basic_value_enum())) - }), - ), - create_fn_by_codegen( - unifier, - &VarMap::new(), - "np_ones", - ndarray_float, - // We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a - // type variable - &[(list_int32, "shape")], - Box::new(|ctx, obj, fun, args, generator| { - gen_ndarray_ones(ctx, &obj, fun, &args, generator) - .map(|val| Some(val.as_basic_value_enum())) - }), - ), - { - let tv = unifier.get_fresh_var(Some("T".into()), None); - - create_fn_by_codegen( - unifier, - &[(tv.1, tv.0)].into_iter().collect(), - "np_full", - ndarray, - // We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a - // type variable - &[(list_int32, "shape"), (tv.0, "fill_value")], - Box::new(|ctx, obj, fun, args, generator| { - gen_ndarray_full(ctx, &obj, fun, &args, generator) - .map(|val| Some(val.as_basic_value_enum())) - }), - ) - }, - Arc::new(RwLock::new(TopLevelDef::Function { - name: "np_eye".into(), - simple_name: "np_eye".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![ - FuncArg { name: "N".into(), ty: int32, default_value: None }, - // TODO(Derppening): Default values current do not work? - FuncArg { - name: "M".into(), - ty: int32, - default_value: Some(SymbolValue::OptionNone) + let returned_int = match prim { + PrimDef::OptionIsNone => { + ctx.builder.build_is_null(ptr, prim.simple_name().into()) + } + PrimDef::OptionIsSome => { + ctx.builder.build_is_not_null(ptr, prim.simple_name().into()) + } + _ => unreachable!(), + }; + Ok(Some(returned_int.map(Into::into).unwrap())) }, - FuncArg { name: "k".into(), ty: int32, default_value: Some(SymbolValue::I32(0)) }, + )))), + loc: None, + }, + + PrimDef::FunSome => TopLevelDef::Function { + name: prim.name().into(), + simple_name: prim.simple_name().into(), + signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![FuncArg { + name: "n".into(), + ty: self.option_tvar.0, + default_value: None, + }], + ret: self.primitives.option, + vars: VarMap::from([(self.option_tvar.1, self.option_tvar.0)]), + })), + var_id: vec![self.option_tvar.1], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, _, fun, args, generator| { + let arg_ty = fun.0.args[0].ty; + let arg_val = + args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; + let alloca = generator + .gen_var_alloc(ctx, arg_val.get_type(), Some("alloca_some")) + .unwrap(); + ctx.builder.build_store(alloca, arg_val).unwrap(); + Ok(Some(alloca.into())) + }, + )))), + loc: None, + }, + + _ => { + unreachable!() + } + } + } + + fn build_ndarray_class_related(&self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed( + prim, + [PrimDef::NDArray, PrimDef::NDArrayCopy, PrimDef::NDArrayFill], + ); + + match prim { + PrimDef::NDArray => TopLevelDef::Class { + name: prim.name().into(), + object_id: prim.id(), + type_vars: vec![self.ndarray_dtype_tvar.0, self.ndarray_ndims_tvar.0], + fields: Vec::default(), + methods: vec![ + self.create_method(PrimDef::NDArrayCopy, self.ndarray_copy_ty.0), + self.create_method(PrimDef::NDArrayFill, self.ndarray_fill_ty.0), ], - ret: ndarray_float_2d, - vars: VarMap::default(), + ancestors: Vec::default(), + constructor: None, + resolver: None, + loc: None, + }, + + PrimDef::NDArrayCopy => TopLevelDef::Function { + name: prim.name().into(), + simple_name: prim.simple_name().into(), + signature: self.ndarray_copy_ty.0, + var_id: vec![self.ndarray_dtype_tvar.1, self.ndarray_ndims_tvar.1], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, obj, fun, args, generator| { + gen_ndarray_copy(ctx, &obj, fun, &args, generator) + .map(|val| Some(val.as_basic_value_enum())) + }, + )))), + loc: None, + }, + + PrimDef::NDArrayFill => TopLevelDef::Function { + name: prim.name().into(), + simple_name: prim.simple_name().into(), + signature: self.ndarray_fill_ty.0, + var_id: vec![self.ndarray_dtype_tvar.1, self.ndarray_ndims_tvar.1], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, obj, fun, args, generator| { + gen_ndarray_fill(ctx, &obj, fun, &args, generator)?; + Ok(None) + }, + )))), + loc: None, + }, + + _ => unreachable!(), + } + } + + fn build_cast_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed( + prim, + [ + PrimDef::FunInt32, + PrimDef::FunInt64, + PrimDef::FunUInt32, + PrimDef::FunUInt64, + PrimDef::FunFloat, + PrimDef::FunBool, + ], + ); + + TopLevelDef::Function { + name: prim.name().into(), + simple_name: prim.simple_name().into(), + signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![FuncArg { + name: "n".into(), + ty: self.num_or_ndarray_ty.0, + default_value: None, + }], + ret: self.num_or_ndarray_ty.0, + vars: self.num_or_ndarray_var_map.clone(), })), var_id: Vec::default(), instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, obj, fun, args, generator| { - gen_ndarray_eye(ctx, &obj, fun, &args, generator) - .map(|val| Some(val.as_basic_value_enum())) + move |ctx, _, fun, args, generator| { + let arg_ty = fun.0.args[0].ty; + let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; + + let func = match prim { + PrimDef::FunInt32 => builtin_fns::call_int32, + PrimDef::FunInt64 => builtin_fns::call_int64, + PrimDef::FunUInt32 => builtin_fns::call_uint32, + PrimDef::FunUInt64 => builtin_fns::call_uint64, + PrimDef::FunFloat => builtin_fns::call_float, + PrimDef::FunBool => builtin_fns::call_bool, + _ => unreachable!(), + }; + Ok(Some(func(generator, ctx, (arg_ty, arg))?)) }, )))), loc: None, - })), + } + } + + fn build_round_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed(prim, [PrimDef::FunRound, PrimDef::FunRound64]); + + let float = self.primitives.float; + + let size_variant = match prim { + PrimDef::FunRound => SizeVariant::Bits32, + PrimDef::FunRound64 => SizeVariant::Bits64, + _ => unreachable!(), + }; + + let common_ndim = self.unifier.get_fresh_const_generic_var( + self.primitives.usize(), + Some("N".into()), + None, + ); + + // The size variant of the function determines the size of the returned int. + let int_sized = size_variant.of_int(self.primitives); + + let ndarray_int_sized = + make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.0)); + let ndarray_float = + make_ndarray_ty(self.unifier, self.primitives, Some(float), Some(common_ndim.0)); + + let p0_ty = + self.unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None); + let ret_ty = self.unifier.get_fresh_var_with_range( + &[int_sized, ndarray_int_sized], + Some("R".into()), + None, + ); + create_fn_by_codegen( - unifier, - &VarMap::new(), - "np_identity", - ndarray_float_2d, - &[(int32, "n")], - Box::new(|ctx, obj, fun, args, generator| { - gen_ndarray_identity(ctx, &obj, fun, &args, generator) - .map(|val| Some(val.as_basic_value_enum())) - }), - ), - { - let common_ndim = unifier.get_fresh_const_generic_var( - primitives.usize(), - Some("N".into()), - None, - ); - let ndarray_int32 = make_ndarray_ty(unifier, primitives, Some(int32), Some(common_ndim.0)); - let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0)); - - let p0_ty = unifier.get_fresh_var_with_range( - &[float, ndarray_float], - Some("T".into()), - None, - ); - let ret_ty = unifier.get_fresh_var_with_range( - &[int32, ndarray_int32], - Some("R".into()), - None, - ); - - create_fn_by_codegen( - unifier, - &[ - (common_ndim.1, common_ndim.0), - (p0_ty.1, p0_ty.0), - (ret_ty.1, ret_ty.0), - ].into_iter().collect(), - "round", - ret_ty.0, - &[(p0_ty.0, "n")], - Box::new(|ctx, _, fun, args, generator| { - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, arg_ty)?; - - Ok(Some(builtin_fns::call_round(generator, ctx, (arg_ty, arg), ctx.primitives.int32)?)) - }), - ) - }, - { - let common_ndim = unifier.get_fresh_const_generic_var( - primitives.usize(), - Some("N".into()), - None, - ); - let ndarray_int64 = make_ndarray_ty(unifier, primitives, Some(int64), Some(common_ndim.0)); - let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0)); - - let p0_ty = unifier.get_fresh_var_with_range( - &[float, ndarray_float], - Some("T".into()), - None, - ); - let ret_ty = unifier.get_fresh_var_with_range( - &[int64, ndarray_int64], - Some("R".into()), - None, - ); - - create_fn_by_codegen( - unifier, - &[ - (common_ndim.1, common_ndim.0), - (p0_ty.1, p0_ty.0), - (ret_ty.1, ret_ty.0), - ].into_iter().collect(), - "round64", - ret_ty.0, - &[(p0_ty.0, "n")], - Box::new(|ctx, _, fun, args, generator| { - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, arg_ty)?; - - Ok(Some(builtin_fns::call_round(generator, ctx, (arg_ty, arg), ctx.primitives.int64)?)) - }), - ) - }, - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_round", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "n")], - Box::new(|ctx, _, fun, args, generator| { + &mut self.unifier, + &[(common_ndim.1, common_ndim.0), (p0_ty.1, p0_ty.0), (ret_ty.1, ret_ty.0)] + .into_iter() + .collect(), + prim.name(), + ret_ty.0, + &[(p0_ty.0, "n")], + Box::new(move |ctx, _, fun, args, generator| { let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, arg_ty)?; + let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(Some(builtin_fns::call_numpy_round(generator, ctx, (arg_ty, arg))?)) + let ret_elem_ty = size_variant.of_int(&ctx.primitives); + Ok(Some(builtin_fns::call_round(generator, ctx, (arg_ty, arg), ret_elem_ty)?)) }), - ), - Arc::new(RwLock::new(TopLevelDef::Function { - name: "range".into(), - simple_name: "range".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { + ) + } + + fn build_ceil_floor_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed( + prim, + [PrimDef::FunFloor, PrimDef::FunFloor64, PrimDef::FunCeil, PrimDef::FunCeil64], + ); + + #[derive(Clone, Copy)] + enum Kind { + Floor, + Ceil, + } + + let (size_variant, kind) = { + match prim { + PrimDef::FunFloor => (SizeVariant::Bits32, Kind::Floor), + PrimDef::FunFloor64 => (SizeVariant::Bits64, Kind::Floor), + PrimDef::FunCeil => (SizeVariant::Bits32, Kind::Ceil), + PrimDef::FunCeil64 => (SizeVariant::Bits64, Kind::Ceil), + _ => unreachable!(), + } + }; + + let float = self.primitives.float; + + let common_ndim = self.unifier.get_fresh_const_generic_var( + self.primitives.usize(), + Some("N".into()), + None, + ); + + let ndarray_float = + make_ndarray_ty(self.unifier, self.primitives, Some(float), Some(common_ndim.0)); + + // The size variant of the function determines the type of int returned + let int_sized = size_variant.of_int(self.primitives); + let ndarray_int_sized = + make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.0)); + + let p0_ty = + self.unifier.get_fresh_var_with_range(&[float, ndarray_float], Some("T".into()), None); + + let ret_ty = self.unifier.get_fresh_var_with_range( + &[int_sized, ndarray_int_sized], + Some("R".into()), + None, + ); + + create_fn_by_codegen( + &mut self.unifier, + &[(common_ndim.1, common_ndim.0), (p0_ty.1, p0_ty.0), (ret_ty.1, ret_ty.0)] + .into_iter() + .collect(), + prim.name(), + ret_ty.0, + &[(p0_ty.0, "n")], + Box::new(move |ctx, _, fun, args, generator| { + let arg_ty = fun.0.args[0].ty; + let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; + + let ret_elem_ty = size_variant.of_int(&ctx.primitives); + let func = match kind { + Kind::Ceil => builtin_fns::call_ceil, + Kind::Floor => builtin_fns::call_floor, + }; + Ok(Some(func(generator, ctx, (arg_ty, arg), ret_elem_ty)?)) + }), + ) + } + + fn build_ndarray_from_shape_factory_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed( + prim, + [PrimDef::FunNpNDArray, PrimDef::FunNpEmpty, PrimDef::FunNpZeros, PrimDef::FunNpOnes], + ); + + create_fn_by_codegen( + &mut self.unifier, + &VarMap::new(), + prim.name(), + self.ndarray_float, + // We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a + // type variable + &[(self.list_int32, "shape")], + Box::new(move |ctx, obj, fun, args, generator| { + let func = match prim { + PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => gen_ndarray_empty, + PrimDef::FunNpZeros => gen_ndarray_zeros, + PrimDef::FunNpOnes => gen_ndarray_ones, + _ => unreachable!(), + }; + func(ctx, &obj, fun, &args, generator).map(|val| Some(val.as_basic_value_enum())) + }), + ) + } + + fn build_ndarray_other_factory_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed( + prim, + [PrimDef::FunNpFull, PrimDef::FunNpEye, PrimDef::FunNpIdentity], + ); + + let int32 = self.primitives.int32; + match prim { + PrimDef::FunNpFull => { + let tv = self.unifier.get_fresh_var(Some("T".into()), None); + + create_fn_by_codegen( + &mut self.unifier, + &[(tv.1, tv.0)].into_iter().collect(), + prim.name(), + self.primitives.ndarray, + // We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a + // type variable + &[(self.list_int32, "shape"), (tv.0, "fill_value")], + Box::new(move |ctx, obj, fun, args, generator| { + gen_ndarray_full(ctx, &obj, fun, &args, generator) + .map(|val| Some(val.as_basic_value_enum())) + }), + ) + } + PrimDef::FunNpEye => { + TopLevelDef::Function { + name: prim.name().into(), + simple_name: prim.simple_name().into(), + signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![ + FuncArg { name: "N".into(), ty: int32, default_value: None }, + // TODO(Derppening): Default values current do not work? + FuncArg { + name: "M".into(), + ty: int32, + default_value: Some(SymbolValue::OptionNone), + }, + FuncArg { + name: "k".into(), + ty: int32, + default_value: Some(SymbolValue::I32(0)), + }, + ], + ret: self.ndarray_float_2d, + vars: VarMap::default(), + })), + var_id: Vec::default(), + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, obj, fun, args, generator| { + gen_ndarray_eye(ctx, &obj, fun, &args, generator) + .map(|val| Some(val.as_basic_value_enum())) + }, + )))), + loc: None, + } + } + PrimDef::FunNpIdentity => create_fn_by_codegen( + &mut self.unifier, + &VarMap::new(), + prim.name(), + self.ndarray_float_2d, + &[(int32, "n")], + Box::new(|ctx, obj, fun, args, generator| { + gen_ndarray_identity(ctx, &obj, fun, &args, generator) + .map(|val| Some(val.as_basic_value_enum())) + }), + ), + _ => unreachable!(), + } + } + + fn build_range_function(&mut self) -> TopLevelDef { + let prim = PrimDef::FunRange; + + let PrimitiveStore { int32, range, .. } = *self.primitives; + + TopLevelDef::Function { + name: prim.name().into(), + simple_name: prim.simple_name().into(), + signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![ FuncArg { name: "start".into(), ty: int32, default_value: None }, FuncArg { @@ -947,13 +1086,15 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built let ty_i32 = ctx.primitives.int32; for (i, arg) in args.iter().enumerate() { if arg.0 == Some("start".into()) { - start = Some(arg.1.clone().to_basic_value_enum(ctx, generator, ty_i32)?); + start = + Some(arg.1.clone().to_basic_value_enum(ctx, generator, ty_i32)?); } else if arg.0 == Some("stop".into()) { stop = Some(arg.1.clone().to_basic_value_enum(ctx, generator, ty_i32)?); } else if arg.0 == Some("step".into()) { step = Some(arg.1.clone().to_basic_value_enum(ctx, generator, ty_i32)?); } else if i == 0 { - start = Some(arg.1.clone().to_basic_value_enum(ctx, generator, ty_i32)?); + start = + Some(arg.1.clone().to_basic_value_enum(ctx, generator, ty_i32)?); } else if i == 1 { stop = Some(arg.1.clone().to_basic_value_enum(ctx, generator, ty_i32)?); } else if i == 2 { @@ -964,7 +1105,8 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built Some(step) => { let step = step.into_int_value(); // assert step != 0, throw exception if not - let not_zero = ctx.builder + let not_zero = ctx + .builder .build_int_compare( IntPredicate::NE, step, @@ -993,21 +1135,16 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built let ty = int32.array_type(3); let ptr = generator.gen_var_alloc(ctx, ty.into(), Some("range")).unwrap(); unsafe { - let a = ctx.builder - .build_in_bounds_gep(ptr, &[zero, zero], "start") + let a = + ctx.builder.build_in_bounds_gep(ptr, &[zero, zero], "start").unwrap(); + let b = ctx + .builder + .build_in_bounds_gep(ptr, &[zero, int32.const_int(1, false)], "end") .unwrap(); - let b = ctx.builder - .build_in_bounds_gep( - ptr, - &[zero, int32.const_int(1, false)], - "end", - ) + let c = ctx + .builder + .build_in_bounds_gep(ptr, &[zero, int32.const_int(2, false)], "step") .unwrap(); - let c = ctx.builder.build_in_bounds_gep( - ptr, - &[zero, int32.const_int(2, false)], - "step", - ).unwrap(); ctx.builder.build_store(a, start).unwrap(); ctx.builder.build_store(b, stop).unwrap(); ctx.builder.build_store(c, step).unwrap(); @@ -1016,13 +1153,20 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built }, )))), loc: None, - })), - Arc::new(RwLock::new(TopLevelDef::Function { - name: "str".into(), - simple_name: "str".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { name: "s".into(), ty: string, default_value: None }], - ret: string, + } + } + + fn build_str_function(&mut self) -> TopLevelDef { + let prim = PrimDef::FunStr; + + let str = self.primitives.str; + + TopLevelDef::Function { + name: prim.name().into(), + simple_name: prim.simple_name().into(), + signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![FuncArg { name: "s".into(), ty: str, default_value: None }], + ret: str, vars: VarMap::default(), })), var_id: Vec::default(), @@ -1036,505 +1180,292 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built }, )))), loc: None, - })), - Arc::new(RwLock::new(TopLevelDef::Function { - name: "bool".into(), - simple_name: "bool".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { name: "n".into(), ty: num_or_ndarray_ty.0, default_value: None }], - ret: num_or_ndarray_ty.0, - vars: num_or_ndarray_var_map.clone(), + } + } + + fn build_np_ceil_floor_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed(prim, [PrimDef::FunNpCeil, PrimDef::FunNpFloor]); + + create_fn_by_codegen( + &mut self.unifier, + &self.float_or_ndarray_var_map, + prim.name(), + self.float_or_ndarray_ty.0, + &[(self.float_or_ndarray_ty.0, "n")], + Box::new(move |ctx, _, fun, args, generator| { + let arg_ty = fun.0.args[0].ty; + let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; + + let func = match prim { + PrimDef::FunNpCeil => builtin_fns::call_ceil, + PrimDef::FunNpFloor => builtin_fns::call_floor, + _ => unreachable!(), + }; + Ok(Some(func(generator, ctx, (arg_ty, arg), ctx.primitives.float)?)) + }), + ) + } + + fn build_np_round_function(&mut self) -> TopLevelDef { + let prim = PrimDef::FunNpRound; + + create_fn_by_codegen( + &mut self.unifier, + &self.float_or_ndarray_var_map, + prim.name(), + self.float_or_ndarray_ty.0, + &[(self.float_or_ndarray_ty.0, "n")], + Box::new(|ctx, _, fun, args, generator| { + let arg_ty = fun.0.args[0].ty; + let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; + Ok(Some(builtin_fns::call_numpy_round(generator, ctx, (arg_ty, arg))?)) + }), + ) + } + + fn build_len_function(&mut self) -> TopLevelDef { + let prim = PrimDef::FunLen; + + let PrimitiveStore { uint64, int32, .. } = *self.primitives; + + let tvar = self.unifier.get_fresh_var(Some("L".into()), None); + let list = self.unifier.add_ty(TypeEnum::TList { ty: tvar.0 }); + let ndims = self.unifier.get_fresh_const_generic_var(uint64, Some("N".into()), None); + let ndarray = make_ndarray_ty(self.unifier, self.primitives, Some(tvar.0), Some(ndims.0)); + + let arg_ty = self.unifier.get_fresh_var_with_range( + &[list, ndarray, self.primitives.range], + Some("I".into()), + None, + ); + TopLevelDef::Function { + name: prim.name().into(), + simple_name: prim.simple_name().into(), + signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![FuncArg { name: "ls".into(), ty: arg_ty.0, default_value: None }], + ret: int32, + vars: vec![(tvar.1, tvar.0), (arg_ty.1, arg_ty.0)].into_iter().collect(), })), var_id: Vec::default(), instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, fun, args, generator| { + move |ctx, _, fun, args, generator| { + let range_ty = ctx.primitives.range; let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - - Ok(Some(builtin_fns::call_bool(generator, ctx, (arg_ty, arg))?)) - }, - )))), - loc: None, - })), - { - let common_ndim = unifier.get_fresh_const_generic_var( - primitives.usize(), - Some("N".into()), - None, - ); - let ndarray_int32 = make_ndarray_ty(unifier, primitives, Some(int32), Some(common_ndim.0)); - let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0)); - - let p0_ty = unifier.get_fresh_var_with_range( - &[float, ndarray_float], - Some("T".into()), - None, - ); - let ret_ty = unifier.get_fresh_var_with_range( - &[int32, ndarray_int32], - Some("R".into()), - None, - ); - - create_fn_by_codegen( - unifier, - &[ - (common_ndim.1, common_ndim.0), - (p0_ty.1, p0_ty.0), - (ret_ty.1, ret_ty.0), - ].into_iter().collect(), - "floor", - ret_ty.0, - &[(p0_ty.0, "n")], - Box::new(|ctx, _, fun, args, generator| { - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, arg_ty)?; - - Ok(Some(builtin_fns::call_floor(generator, ctx, (arg_ty, arg), ctx.primitives.int32)?)) - }), - ) - }, - { - let common_ndim = unifier.get_fresh_const_generic_var( - primitives.usize(), - Some("N".into()), - None, - ); - let ndarray_int64 = make_ndarray_ty(unifier, primitives, Some(int64), Some(common_ndim.0)); - let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0)); - - let p0_ty = unifier.get_fresh_var_with_range( - &[float, ndarray_float], - Some("T".into()), - None, - ); - let ret_ty = unifier.get_fresh_var_with_range( - &[int64, ndarray_int64], - Some("R".into()), - None, - ); - - create_fn_by_codegen( - unifier, - &[ - (common_ndim.1, common_ndim.0), - (p0_ty.1, p0_ty.0), - (ret_ty.1, ret_ty.0), - ].into_iter().collect(), - "floor64", - ret_ty.0, - &[(p0_ty.0, "n")], - Box::new(|ctx, _, fun, args, generator| { - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, arg_ty)?; - - Ok(Some(builtin_fns::call_floor(generator, ctx, (arg_ty, arg), ctx.primitives.int64)?)) - }), - ) - }, - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_floor", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "n")], - Box::new(|ctx, _, fun, args, generator| { - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, arg_ty)?; - - Ok(Some(builtin_fns::call_floor(generator, ctx, (arg_ty, arg), ctx.primitives.float)?)) - }), - ), - { - let common_ndim = unifier.get_fresh_const_generic_var( - primitives.usize(), - Some("N".into()), - None, - ); - let ndarray_int32 = make_ndarray_ty(unifier, primitives, Some(int32), Some(common_ndim.0)); - let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0)); - - let p0_ty = unifier.get_fresh_var_with_range( - &[float, ndarray_float], - Some("T".into()), - None, - ); - let ret_ty = unifier.get_fresh_var_with_range( - &[int32, ndarray_int32], - Some("R".into()), - None, - ); - - create_fn_by_codegen( - unifier, - &[ - (common_ndim.1, common_ndim.0), - (p0_ty.1, p0_ty.0), - (ret_ty.1, ret_ty.0), - ].into_iter().collect(), - "ceil", - ret_ty.0, - &[(p0_ty.0, "n")], - Box::new(|ctx, _, fun, args, generator| { - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, arg_ty)?; - - Ok(Some(builtin_fns::call_ceil(generator, ctx, (arg_ty, arg), ctx.primitives.int32)?)) - }), - ) - }, - { - let common_ndim = unifier.get_fresh_const_generic_var( - primitives.usize(), - Some("N".into()), - None, - ); - let ndarray_int64 = make_ndarray_ty(unifier, primitives, Some(int64), Some(common_ndim.0)); - let ndarray_float = make_ndarray_ty(unifier, primitives, Some(float), Some(common_ndim.0)); - - let p0_ty = unifier.get_fresh_var_with_range( - &[float, ndarray_float], - Some("T".into()), - None, - ); - let ret_ty = unifier.get_fresh_var_with_range( - &[int64, ndarray_int64], - Some("R".into()), - None, - ); - - create_fn_by_codegen( - unifier, - &[ - (common_ndim.1, common_ndim.0), - (p0_ty.1, p0_ty.0), - (ret_ty.1, ret_ty.0), - ].into_iter().collect(), - "ceil64", - ret_ty.0, - &[(p0_ty.0, "n")], - Box::new(|ctx, _, fun, args, generator| { - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, arg_ty)?; - - Ok(Some(builtin_fns::call_ceil(generator, ctx, (arg_ty, arg), ctx.primitives.int64)?)) - }), - ) - }, - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_ceil", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "n")], - Box::new(|ctx, _, fun, args, generator| { - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone() - .to_basic_value_enum(ctx, generator, arg_ty)?; - - Ok(Some(builtin_fns::call_ceil(generator, ctx, (arg_ty, arg), ctx.primitives.float)?)) - }), - ), - Arc::new(RwLock::new({ - let tvar = unifier.get_fresh_var(Some("L".into()), None); - let list = unifier.add_ty(TypeEnum::TList { ty: tvar.0 }); - let ndims = unifier.get_fresh_const_generic_var(primitives.uint64, Some("N".into()), None); - let ndarray = make_ndarray_ty( - unifier, - primitives, - Some(tvar.0), - Some(ndims.0), - ); - - let arg_ty = unifier.get_fresh_var_with_range( - &[list, ndarray, primitives.range], - Some("I".into()), - None, - ); - TopLevelDef::Function { - name: "len".into(), - simple_name: "len".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { name: "ls".into(), ty: arg_ty.0, default_value: None }], - ret: int32, - vars: vec![(tvar.1, tvar.0), (arg_ty.1, arg_ty.0)] - .into_iter() - .collect(), - })), - var_id: Vec::default(), - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, fun, args, generator| { - let range_ty = ctx.primitives.range; - let arg_ty = fun.0.args[0].ty; - let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(if ctx.unifier.unioned(arg_ty, range_ty) { - let arg = RangeValue::from_ptr_val(arg.into_pointer_value(), Some("range")); - let (start, end, step) = destructure_range(ctx, arg); - Some(calculate_len_for_slice_range(generator, ctx, start, end, step).into()) - } else { - match &*ctx.unifier.get_ty_immutable(arg_ty) { - TypeEnum::TList { .. } => { - let int32 = ctx.ctx.i32_type(); - let zero = int32.const_zero(); - let len = ctx - .build_gep_and_load( - arg.into_pointer_value(), - &[zero, int32.const_int(1, false)], - None, - ) - .into_int_value(); - if len.get_type().get_bit_width() == 32 { - Some(len.into()) - } else { - Some(ctx.builder + Ok(if ctx.unifier.unioned(arg_ty, range_ty) { + let arg = RangeValue::from_ptr_val(arg.into_pointer_value(), Some("range")); + let (start, end, step) = destructure_range(ctx, arg); + Some(calculate_len_for_slice_range(generator, ctx, start, end, step).into()) + } else { + match &*ctx.unifier.get_ty_immutable(arg_ty) { + TypeEnum::TList { .. } => { + let int32 = ctx.ctx.i32_type(); + let zero = int32.const_zero(); + let len = ctx + .build_gep_and_load( + arg.into_pointer_value(), + &[zero, int32.const_int(1, false)], + None, + ) + .into_int_value(); + if len.get_type().get_bit_width() == 32 { + Some(len.into()) + } else { + Some( + ctx.builder .build_int_truncate(len, int32, "len2i32") .map(Into::into) - .unwrap() - ) - } + .unwrap(), + ) } - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimitiveDefinition::NDArray.id() => { - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); + } + TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); - let arg = NDArrayValue::from_ptr_val( - arg.into_pointer_value(), - llvm_usize, - None - ); + let arg = NDArrayValue::from_ptr_val( + arg.into_pointer_value(), + llvm_usize, + None, + ); - let ndims = arg.dim_sizes().size(ctx, generator); - ctx.make_assert( - generator, - ctx.builder.build_int_compare( + let ndims = arg.dim_sizes().size(ctx, generator); + ctx.make_assert( + generator, + ctx.builder + .build_int_compare( IntPredicate::NE, ndims, llvm_usize.const_zero(), "", - ).unwrap(), - "0:TypeError", - "len() of unsized object", - [None, None, None], - ctx.current_loc, - ); - - let len = unsafe { - arg.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_zero(), - None, ) - }; + .unwrap(), + "0:TypeError", + &format!("{name}() of unsized object", name = prim.name()), + [None, None, None], + ctx.current_loc, + ); - if len.get_type().get_bit_width() == 32 { - Some(len.into()) - } else { - Some(ctx.builder - .build_int_truncate(len, llvm_i32, "len") - .map(Into::into) - .unwrap() - ) - } + let len = unsafe { + arg.dim_sizes().get_typed_unchecked( + ctx, + generator, + &llvm_usize.const_zero(), + None, + ) + }; + + if len.get_type().get_bit_width() == 32 { + Some(len.into()) + } else { + Some( + ctx.builder + .build_int_truncate(len, llvm_i32, "len") + .map(Into::into) + .unwrap(), + ) } - _ => unreachable!(), } - }) - }, - )))), - loc: None, - } - })), - Arc::new(RwLock::new(TopLevelDef::Function { - name: "min".into(), - simple_name: "min".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { + _ => unreachable!(), + } + }) + }, + )))), + loc: None, + } + } + + fn build_min_max_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed(prim, [PrimDef::FunMin, PrimDef::FunMax]); + + TopLevelDef::Function { + name: prim.name().into(), + simple_name: prim.simple_name().into(), + signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![ - FuncArg { name: "m".into(), ty: num_ty.0, default_value: None }, - FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }, + FuncArg { name: "m".into(), ty: self.num_ty.0, default_value: None }, + FuncArg { name: "n".into(), ty: self.num_ty.0, default_value: None }, ], - ret: num_ty.0, - vars: num_var_map.clone(), + ret: self.num_ty.0, + vars: self.num_var_map.clone(), })), var_id: Vec::default(), instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, fun, args, generator| { + move |ctx, _, fun, args, generator| { let m_ty = fun.0.args[0].ty; let n_ty = fun.0.args[1].ty; let m_val = args[0].1.clone().to_basic_value_enum(ctx, generator, m_ty)?; let n_val = args[1].1.clone().to_basic_value_enum(ctx, generator, n_ty)?; - Ok(Some(builtin_fns::call_min(ctx, (m_ty, m_val), (n_ty, n_val)))) + let func = match prim { + PrimDef::FunMin => builtin_fns::call_min, + PrimDef::FunMax => builtin_fns::call_max, + _ => unreachable!(), + }; + Ok(Some(func(ctx, (m_ty, m_val), (n_ty, n_val)))) }, )))), loc: None, - })), - { - let ret_ty = unifier.get_fresh_var(Some("R".into()), None); - let var_map = num_or_ndarray_var_map.clone() - .into_iter() - .chain(once((ret_ty.1, ret_ty.0))) - .collect::>(); + } + } - create_fn_by_codegen( - unifier, - &var_map, - "np_min", - ret_ty.0, - &[(float_or_ndarray_ty.0, "a")], - Box::new(|ctx, _, fun, args, generator| { - let a_ty = fun.0.args[0].ty; - let a = args[0].1.clone() - .to_basic_value_enum(ctx, generator, a_ty)?; + fn build_np_min_max_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed(prim, [PrimDef::FunNpMin, PrimDef::FunNpMax]); - Ok(Some(builtin_fns::call_numpy_min(generator, ctx, (a_ty, a))?)) - }), - ) - }, - { - let x1_ty = new_type_or_ndarray_ty(unifier, primitives, num_ty.0); - let x2_ty = new_type_or_ndarray_ty(unifier, primitives, num_ty.0); - let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")]; - let ret_ty = unifier.get_fresh_var(None, None); + let ret_ty = self.unifier.get_fresh_var(Some("R".into()), None); + let var_map = self + .num_or_ndarray_var_map + .clone() + .into_iter() + .chain(once((ret_ty.1, ret_ty.0))) + .collect::>(); - Arc::new(RwLock::new(TopLevelDef::Function { - name: "np_minimum".into(), - simple_name: "np_minimum".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: param_ty.iter().map(|p| FuncArg { - name: p.1.into(), - ty: p.0, - default_value: None, - }).collect(), - ret: ret_ty.0, - vars: [ - (x1_ty.1, x1_ty.0), - (x2_ty.1, x2_ty.0), - (ret_ty.1, ret_ty.0), - ].into_iter().collect(), - })), - var_id: vec![x1_ty.1, x2_ty.1], - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x2_ty = fun.0.args[1].ty; - let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; - let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - - Ok(Some(builtin_fns::call_numpy_minimum(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) - })))), - loc: None, - })) - }, - Arc::new(RwLock::new(TopLevelDef::Function { - name: "max".into(), - simple_name: "max".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![ - FuncArg { name: "m".into(), ty: num_ty.0, default_value: None }, - FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }, - ], - ret: num_ty.0, - vars: num_var_map.clone(), + create_fn_by_codegen( + &mut self.unifier, + &var_map, + prim.name(), + ret_ty.0, + &[(self.float_or_ndarray_ty.0, "a")], + Box::new(move |ctx, _, fun, args, generator| { + let a_ty = fun.0.args[0].ty; + let a = args[0].1.clone().to_basic_value_enum(ctx, generator, a_ty)?; + + let func = match prim { + PrimDef::FunNpMin => builtin_fns::call_numpy_min, + PrimDef::FunNpMax => builtin_fns::call_numpy_max, + _ => unreachable!(), + }; + + Ok(Some(func(generator, ctx, (a_ty, a))?)) + }), + ) + } + + fn build_np_minimum_maximum_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed(prim, [PrimDef::FunNpMinimum, PrimDef::FunNpMaximum]); + + let x1_ty = self.new_type_or_ndarray_ty(self.num_ty.0); + let x2_ty = self.new_type_or_ndarray_ty(self.num_ty.0); + let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")]; + let ret_ty = self.unifier.get_fresh_var(None, None); + + TopLevelDef::Function { + name: prim.name().into(), + simple_name: prim.simple_name().into(), + signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: param_ty + .iter() + .map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None }) + .collect(), + ret: ret_ty.0, + vars: [(x1_ty.1, x1_ty.0), (x2_ty.1, x2_ty.0), (ret_ty.1, ret_ty.0)] + .into_iter() + .collect(), })), - var_id: Vec::default(), + var_id: vec![x1_ty.1, x2_ty.1], instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, fun, args, generator| { - let m_ty = fun.0.args[0].ty; - let n_ty = fun.0.args[1].ty; - let m_val = args[0].1.clone().to_basic_value_enum(ctx, generator, m_ty)?; - let n_val = args[1].1.clone().to_basic_value_enum(ctx, generator, n_ty)?; - - Ok(Some(builtin_fns::call_max(ctx, (m_ty, m_val), (n_ty, n_val)))) - }, - )))), - loc: None, - })), - { - let ret_ty = unifier.get_fresh_var(Some("R".into()), None); - let var_map = num_or_ndarray_var_map.clone() - .into_iter() - .chain(once((ret_ty.1, ret_ty.0))) - .collect::>(); - - create_fn_by_codegen( - unifier, - &var_map, - "np_max", - ret_ty.0, - &[(float_or_ndarray_ty.0, "a")], - Box::new(|ctx, _, fun, args, generator| { - let a_ty = fun.0.args[0].ty; - let a = args[0].1.clone() - .to_basic_value_enum(ctx, generator, a_ty)?; - - Ok(Some(builtin_fns::call_numpy_max(generator, ctx, (a_ty, a))?)) - }), - ) - }, - { - let x1_ty = new_type_or_ndarray_ty(unifier, primitives, num_ty.0); - let x2_ty = new_type_or_ndarray_ty(unifier, primitives, num_ty.0); - let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")]; - let ret_ty = unifier.get_fresh_var(None, None); - - Arc::new(RwLock::new(TopLevelDef::Function { - name: "np_maximum".into(), - simple_name: "np_maximum".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: param_ty.iter().map(|p| FuncArg { - name: p.1.into(), - ty: p.0, - default_value: None, - }).collect(), - ret: ret_ty.0, - vars: [ - (x1_ty.1, x1_ty.0), - (x2_ty.1, x2_ty.0), - (ret_ty.1, ret_ty.0), - ].into_iter().collect(), - })), - var_id: vec![x1_ty.1, x2_ty.1], - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, fun, args, generator| { + move |ctx, _, fun, args, generator| { let x1_ty = fun.0.args[0].ty; let x2_ty = fun.0.args[1].ty; let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - Ok(Some(builtin_fns::call_numpy_maximum(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) - })))), - loc: None, - })) - }, - Arc::new(RwLock::new(TopLevelDef::Function { - name: "abs".into(), - simple_name: "abs".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { name: "n".into(), ty: num_or_ndarray_ty.0, default_value: None }], - ret: num_or_ndarray_ty.0, - vars: num_or_ndarray_var_map.clone(), + let func = match prim { + PrimDef::FunNpMinimum => builtin_fns::call_numpy_minimum, + PrimDef::FunNpMaximum => builtin_fns::call_numpy_maximum, + _ => unreachable!(), + }; + + Ok(Some(func(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) + }, + )))), + loc: None, + } + } + + fn build_abs_function(&mut self) -> TopLevelDef { + let prim = PrimDef::FunAbs; + + TopLevelDef::Function { + name: prim.name().into(), + simple_name: prim.simple_name().into(), + signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![FuncArg { + name: "n".into(), + ty: self.num_or_ndarray_ty.0, + default_value: None, + }], + ret: self.num_or_ndarray_ty.0, + vars: self.num_or_ndarray_var_map.clone(), })), var_id: Vec::default(), instance_to_symbol: HashMap::default(), @@ -1549,763 +1480,219 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built }, )))), loc: None, - })), + } + } + + /// Build a numpy function that takes in a single float to returns a boolean + fn build_np_float_to_bool_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed(prim, [PrimDef::FunNpIsInf, PrimDef::FunNpIsNan]); + + let PrimitiveStore { bool, float, .. } = *self.primitives; + create_fn_by_codegen( - unifier, + &mut self.unifier, &VarMap::new(), - "np_isnan", - boolean, + prim.name(), + bool, &[(float, "x")], - Box::new(|ctx, _, fun, args, generator| { + Box::new(move |ctx, _, fun, args, generator| { let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; + let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?; - Ok(Some(builtin_fns::call_numpy_isnan(generator, ctx, (x_ty, x_val))?)) + let func = match prim { + PrimDef::FunNpIsInf => builtin_fns::call_numpy_isinf, + PrimDef::FunNpIsNan => builtin_fns::call_numpy_isnan, + _ => unreachable!(), + }; + + Ok(Some(func(generator, ctx, (x_ty, x_val))?)) }), - ), + ) + } + + /// Build a 1-ary numpy/scipy function that takes in a float or an ndarray and returns a value of the same type + fn build_np_sp_float_or_ndarray_1ary_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed( + prim, + [ + PrimDef::FunNpSin, + PrimDef::FunNpCos, + PrimDef::FunNpTan, + PrimDef::FunNpArcsin, + PrimDef::FunNpArccos, + PrimDef::FunNpArctan, + PrimDef::FunNpSinh, + PrimDef::FunNpCosh, + PrimDef::FunNpTanh, + PrimDef::FunNpArcsinh, + PrimDef::FunNpArccosh, + PrimDef::FunNpArctanh, + PrimDef::FunNpExp, + PrimDef::FunNpExp2, + PrimDef::FunNpExpm1, + PrimDef::FunNpLog, + PrimDef::FunNpLog2, + PrimDef::FunNpLog10, + PrimDef::FunNpSqrt, + PrimDef::FunNpCbrt, + PrimDef::FunNpFabs, + PrimDef::FunNpRint, + PrimDef::FunSpSpecErf, + PrimDef::FunSpSpecErfc, + PrimDef::FunSpSpecGamma, + PrimDef::FunSpSpecGammaln, + PrimDef::FunSpSpecJ0, + PrimDef::FunSpSpecJ1, + ], + ); + + // The parameter name of the sole input of this function. + // Usually this is just "x", but some functions have a different parameter name. + let arg_name = match prim { + PrimDef::FunSpSpecErf => "z", + _ => "x", + }; + create_fn_by_codegen( - unifier, - &VarMap::new(), - "np_isinf", - boolean, - &[(float, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; + &mut self.unifier, + &self.float_or_ndarray_var_map, + prim.name(), + self.float_or_ndarray_ty.0, + &[(self.float_or_ndarray_ty.0, arg_name)], + Box::new(move |ctx, _, fun, args, generator| { + let arg_ty = fun.0.args[0].ty; + let arg_val = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(Some(builtin_fns::call_numpy_isinf(generator, ctx, (x_ty, x_val))?)) + let func = match prim { + PrimDef::FunNpSin => builtin_fns::call_numpy_sin, + PrimDef::FunNpCos => builtin_fns::call_numpy_cos, + PrimDef::FunNpTan => builtin_fns::call_numpy_tan, + + PrimDef::FunNpArcsin => builtin_fns::call_numpy_arcsin, + PrimDef::FunNpArccos => builtin_fns::call_numpy_arccos, + PrimDef::FunNpArctan => builtin_fns::call_numpy_arctan, + + PrimDef::FunNpSinh => builtin_fns::call_numpy_sinh, + PrimDef::FunNpCosh => builtin_fns::call_numpy_cosh, + PrimDef::FunNpTanh => builtin_fns::call_numpy_tanh, + + PrimDef::FunNpArcsinh => builtin_fns::call_numpy_arcsinh, + PrimDef::FunNpArccosh => builtin_fns::call_numpy_arccosh, + PrimDef::FunNpArctanh => builtin_fns::call_numpy_arctanh, + + PrimDef::FunNpExp => builtin_fns::call_numpy_exp, + PrimDef::FunNpExp2 => builtin_fns::call_numpy_exp2, + PrimDef::FunNpExpm1 => builtin_fns::call_numpy_expm1, + + PrimDef::FunNpLog => builtin_fns::call_numpy_log, + PrimDef::FunNpLog2 => builtin_fns::call_numpy_log2, + PrimDef::FunNpLog10 => builtin_fns::call_numpy_log10, + + PrimDef::FunNpSqrt => builtin_fns::call_numpy_sqrt, + PrimDef::FunNpCbrt => builtin_fns::call_numpy_cbrt, + + PrimDef::FunNpFabs => builtin_fns::call_numpy_fabs, + PrimDef::FunNpRint => builtin_fns::call_numpy_rint, + + PrimDef::FunSpSpecErf => builtin_fns::call_scipy_special_erf, + PrimDef::FunSpSpecErfc => builtin_fns::call_scipy_special_erfc, + + PrimDef::FunSpSpecGamma => builtin_fns::call_scipy_special_gamma, + PrimDef::FunSpSpecGammaln => builtin_fns::call_scipy_special_gammaln, + + PrimDef::FunSpSpecJ0 => builtin_fns::call_scipy_special_j0, + PrimDef::FunSpSpecJ1 => builtin_fns::call_scipy_special_j1, + + _ => unreachable!(), + }; + Ok(Some(func(generator, ctx, (arg_ty, arg_val))?)) }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_sin", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; + ) + } - Ok(Some(builtin_fns::call_numpy_sin(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_cos", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; + /// Build a 2-ary numpy function + fn build_np_2ary_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed( + prim, + [ + PrimDef::FunNpArctan2, + PrimDef::FunNpCopysign, + PrimDef::FunNpFmax, + PrimDef::FunNpFmin, + PrimDef::FunNpLdExp, + PrimDef::FunNpHypot, + PrimDef::FunNpNextAfter, + ], + ); - Ok(Some(builtin_fns::call_numpy_cos(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_exp", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; + let PrimitiveStore { float, int32, .. } = *self.primitives; - Ok(Some(builtin_fns::call_numpy_exp(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_exp2", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; + let (x1_ty, x2_ty) = match prim { + PrimDef::FunNpArctan2 => (float, float), + PrimDef::FunNpCopysign => (float, float), + PrimDef::FunNpFmax => (float, float), + PrimDef::FunNpFmin => (float, float), + PrimDef::FunNpLdExp => (float, int32), + PrimDef::FunNpHypot => (float, float), + PrimDef::FunNpNextAfter => (float, float), + _ => unreachable!(), + }; - Ok(Some(builtin_fns::call_numpy_exp2(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_log", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; + let x1_ty = self.new_type_or_ndarray_ty(x1_ty); + let x2_ty = self.new_type_or_ndarray_ty(x2_ty); - Ok(Some(builtin_fns::call_numpy_log(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_log10", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; + let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")]; + let ret_ty = self.unifier.get_fresh_var(None, None); - Ok(Some(builtin_fns::call_numpy_log10(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_log2", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; - - Ok(Some(builtin_fns::call_numpy_log2(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_fabs", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; - - Ok(Some(builtin_fns::call_numpy_fabs(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_sqrt", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; - - Ok(Some(builtin_fns::call_numpy_sqrt(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_rint", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; - - Ok(Some(builtin_fns::call_numpy_rint(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_tan", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; - - Ok(Some(builtin_fns::call_numpy_tan(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_arcsin", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; - - Ok(Some(builtin_fns::call_numpy_arcsin(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_arccos", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; - - Ok(Some(builtin_fns::call_numpy_arccos(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_arctan", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; - - Ok(Some(builtin_fns::call_numpy_arctan(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_sinh", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; - - Ok(Some(builtin_fns::call_numpy_sinh(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_cosh", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; - - Ok(Some(builtin_fns::call_numpy_cosh(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_tanh", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; - - Ok(Some(builtin_fns::call_numpy_tanh(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_arcsinh", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; - - Ok(Some(builtin_fns::call_numpy_arcsinh(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_arccosh", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; - - Ok(Some(builtin_fns::call_numpy_arccosh(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_arctanh", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; - - Ok(Some(builtin_fns::call_numpy_arctanh(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_expm1", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; - - Ok(Some(builtin_fns::call_numpy_expm1(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "np_cbrt", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; - - Ok(Some(builtin_fns::call_numpy_cbrt(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "sp_spec_erf", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "z")], - Box::new(|ctx, _, fun, args, generator| { - let z_ty = fun.0.args[0].ty; - let z_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, z_ty)?; - - Ok(Some(builtin_fns::call_scipy_special_erf(generator, ctx, (z_ty, z_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "sp_spec_erfc", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let z_ty = fun.0.args[0].ty; - let z_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, z_ty)?; - - Ok(Some(builtin_fns::call_scipy_special_erfc(generator, ctx, (z_ty, z_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "sp_spec_gamma", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "z")], - Box::new(|ctx, _, fun, args, generator| { - let z_ty = fun.0.args[0].ty; - let z_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, z_ty)?; - - Ok(Some(builtin_fns::call_scipy_special_gamma(generator, ctx, (z_ty, z_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "sp_spec_gammaln", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; - - Ok(Some(builtin_fns::call_scipy_special_gammaln(generator, ctx, (x_ty, x_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "sp_spec_j0", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let z_ty = fun.0.args[0].ty; - let z_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, z_ty)?; - - Ok(Some(builtin_fns::call_scipy_special_j0(generator, ctx, (z_ty, z_val))?)) - }), - ), - create_fn_by_codegen( - unifier, - &float_or_ndarray_var_map, - "sp_spec_j1", - float_or_ndarray_ty.0, - &[(float_or_ndarray_ty.0, "x")], - Box::new(|ctx, _, fun, args, generator| { - let x_ty = fun.0.args[0].ty; - let x_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x_ty)?; - - Ok(Some(builtin_fns::call_scipy_special_j1(generator, ctx, (x_ty, x_val))?)) - }), - ), - // Not mapped: jv/yv, libm only supports integer orders. - { - let x1_ty = new_type_or_ndarray_ty(unifier, primitives, float); - let x2_ty = new_type_or_ndarray_ty(unifier, primitives, float); - let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")]; - let ret_ty = unifier.get_fresh_var(None, None); - - Arc::new(RwLock::new(TopLevelDef::Function { - name: "np_arctan2".into(), - simple_name: "np_arctan2".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: param_ty.iter().map(|p| FuncArg { - name: p.1.into(), - ty: p.0, - default_value: None, - }).collect(), - ret: ret_ty.0, - vars: [ - (x1_ty.1, x1_ty.0), - (x2_ty.1, x2_ty.0), - (ret_ty.1, ret_ty.0), - ].into_iter().collect(), - })), - var_id: vec![ret_ty.1], - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x1_ty)?; - let x2_ty = fun.0.args[1].ty; - let x2_val = args[1].1.clone() - .to_basic_value_enum(ctx, generator, x2_ty)?; - - Ok(Some(builtin_fns::call_numpy_arctan2( - generator, - ctx, - (x1_ty, x1_val), - (x2_ty, x2_val), - )?)) - })))), - loc: None, - })) - }, - { - let x1_ty = new_type_or_ndarray_ty(unifier, primitives, float); - let x2_ty = new_type_or_ndarray_ty(unifier, primitives, float); - let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")]; - let ret_ty = unifier.get_fresh_var(None, None); - - Arc::new(RwLock::new(TopLevelDef::Function { - name: "np_copysign".into(), - simple_name: "np_copysign".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: param_ty.iter().map(|p| FuncArg { - name: p.1.into(), - ty: p.0, - default_value: None, - }).collect(), - ret: ret_ty.0, - vars: [ - (x1_ty.1, x1_ty.0), - (x2_ty.1, x2_ty.0), - (ret_ty.1, ret_ty.0), - ].into_iter().collect(), - })), - var_id: vec![ret_ty.1], - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x1_ty)?; - let x2_ty = fun.0.args[1].ty; - let x2_val = args[1].1.clone() - .to_basic_value_enum(ctx, generator, x2_ty)?; - - Ok(Some(builtin_fns::call_numpy_copysign( - generator, - ctx, - (x1_ty, x1_val), - (x2_ty, x2_val), - )?)) - })))), - loc: None, - })) - }, - { - let x1_ty = new_type_or_ndarray_ty(unifier, primitives, float); - let x2_ty = new_type_or_ndarray_ty(unifier, primitives, float); - let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")]; - let ret_ty = unifier.get_fresh_var(None, None); - - Arc::new(RwLock::new(TopLevelDef::Function { - name: "np_fmax".into(), - simple_name: "np_fmax".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: param_ty.iter().map(|p| FuncArg { - name: p.1.into(), - ty: p.0, - default_value: None, - }).collect(), - ret: ret_ty.0, - vars: [ - (x1_ty.1, x1_ty.0), - (x2_ty.1, x2_ty.0), - (ret_ty.1, ret_ty.0), - ].into_iter().collect(), - })), - var_id: vec![x1_ty.1, x2_ty.1], - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x1_ty)?; - let x2_ty = fun.0.args[1].ty; - let x2_val = args[1].1.clone() - .to_basic_value_enum(ctx, generator, x2_ty)?; - - Ok(Some(builtin_fns::call_numpy_fmax( - generator, - ctx, - (x1_ty, x1_val), - (x2_ty, x2_val), - )?)) - })))), - loc: None, - })) - }, - { - let x1_ty = new_type_or_ndarray_ty(unifier, primitives, float); - let x2_ty = new_type_or_ndarray_ty(unifier, primitives, float); - let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")]; - let ret_ty = unifier.get_fresh_var(None, None); - - Arc::new(RwLock::new(TopLevelDef::Function { - name: "np_fmin".into(), - simple_name: "np_fmin".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: param_ty.iter().map(|p| FuncArg { - name: p.1.into(), - ty: p.0, - default_value: None, - }).collect(), - ret: ret_ty.0, - vars: [ - (x1_ty.1, x1_ty.0), - (x2_ty.1, x2_ty.0), - (ret_ty.1, ret_ty.0), - ].into_iter().collect(), - })), - var_id: vec![x1_ty.1, x2_ty.1], - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x1_ty)?; - let x2_ty = fun.0.args[1].ty; - let x2_val = args[1].1.clone() - .to_basic_value_enum(ctx, generator, x2_ty)?; - - Ok(Some(builtin_fns::call_numpy_fmin( - generator, - ctx, - (x1_ty, x1_val), - (x2_ty, x2_val), - )?)) - })))), - loc: None, - })) - }, - { - let x1_ty = new_type_or_ndarray_ty(unifier, primitives, float); - let x2_ty = new_type_or_ndarray_ty(unifier, primitives, int32); - let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")]; - let ret_ty = unifier.get_fresh_var(None, None); - - Arc::new(RwLock::new(TopLevelDef::Function { - name: "np_ldexp".into(), - simple_name: "np_ldexp".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: param_ty.iter().map(|p| FuncArg { - name: p.1.into(), - ty: p.0, - default_value: None, - }).collect(), - ret: ret_ty.0, - vars: [ - (x1_ty.1, x1_ty.0), - (x2_ty.1, x2_ty.0), - (ret_ty.1, ret_ty.0), - ].into_iter().collect(), - })), - var_id: vec![x1_ty.1, x2_ty.1], - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x1_ty)?; - let x2_ty = fun.0.args[1].ty; - let x2_val = args[1].1.clone() - .to_basic_value_enum(ctx, generator, x2_ty)?; - - Ok(Some(builtin_fns::call_numpy_ldexp( - generator, - ctx, - (x1_ty, x1_val), - (x2_ty, x2_val), - )?)) - })))), - loc: None, - })) - }, - { - let x1_ty = new_type_or_ndarray_ty(unifier, primitives, float); - let x2_ty = new_type_or_ndarray_ty(unifier, primitives, float); - let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")]; - let ret_ty = unifier.get_fresh_var(None, None); - - Arc::new(RwLock::new(TopLevelDef::Function { - name: "np_hypot".into(), - simple_name: "np_hypot".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: param_ty.iter().map(|p| FuncArg { - name: p.1.into(), - ty: p.0, - default_value: None, - }).collect(), - ret: ret_ty.0, - vars: [ - (x1_ty.1, x1_ty.0), - (x2_ty.1, x2_ty.0), - (ret_ty.1, ret_ty.0), - ].into_iter().collect(), - })), - var_id: vec![x1_ty.1, x2_ty.1], - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x1_ty)?; - let x2_ty = fun.0.args[1].ty; - let x2_val = args[1].1.clone() - .to_basic_value_enum(ctx, generator, x2_ty)?; - - Ok(Some(builtin_fns::call_numpy_hypot( - generator, - ctx, - (x1_ty, x1_val), - (x2_ty, x2_val), - )?)) - })))), - loc: None, - })) - }, - { - let x1_ty = new_type_or_ndarray_ty(unifier, primitives, float); - let x2_ty = new_type_or_ndarray_ty(unifier, primitives, float); - let param_ty = &[(x1_ty.0, "x1"), (x2_ty.0, "x2")]; - let ret_ty = unifier.get_fresh_var(None, None); - - Arc::new(RwLock::new(TopLevelDef::Function { - name: "np_nextafter".into(), - simple_name: "np_nextafter".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: param_ty.iter().map(|p| FuncArg { - name: p.1.into(), - ty: p.0, - default_value: None, - }).collect(), - ret: ret_ty.0, - vars: [ - (x1_ty.1, x1_ty.0), - (x2_ty.1, x2_ty.0), - (ret_ty.1, ret_ty.0), - ].into_iter().collect(), - })), - var_id: vec![x1_ty.1, x2_ty.1], - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new(|ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = args[0].1.clone() - .to_basic_value_enum(ctx, generator, x1_ty)?; - let x2_ty = fun.0.args[1].ty; - let x2_val = args[1].1.clone() - .to_basic_value_enum(ctx, generator, x2_ty)?; - - Ok(Some(builtin_fns::call_numpy_nextafter( - generator, - ctx, - (x1_ty, x1_val), - (x2_ty, x2_val), - )?)) - })))), - loc: None, - })) - }, - Arc::new(RwLock::new(TopLevelDef::Function { - name: "Some".into(), - simple_name: "Some".into(), - signature: unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![FuncArg { name: "n".into(), ty: option_ty_var, default_value: None }], - ret: primitives.option, - vars: VarMap::from([(option_ty_var_id, option_ty_var)]), + TopLevelDef::Function { + name: prim.name().into(), + simple_name: prim.simple_name().into(), + signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: param_ty + .iter() + .map(|p| FuncArg { name: p.1.into(), ty: p.0, default_value: None }) + .collect(), + ret: ret_ty.0, + vars: [(x1_ty.1, x1_ty.0), (x2_ty.1, x2_ty.0), (ret_ty.1, ret_ty.0)] + .into_iter() + .collect(), })), - var_id: vec![option_ty_var_id], + var_id: vec![ret_ty.1], instance_to_symbol: HashMap::default(), instance_to_stmt: HashMap::default(), resolver: None, codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, fun, args, generator| { - let arg_ty = fun.0.args[0].ty; - let arg_val = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - let alloca = generator.gen_var_alloc(ctx, arg_val.get_type(), Some("alloca_some")).unwrap(); - ctx.builder.build_store(alloca, arg_val).unwrap(); - Ok(Some(alloca.into())) + move |ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_ty = fun.0.args[1].ty; + let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; + + let func = match prim { + PrimDef::FunNpArctan2 => builtin_fns::call_numpy_arctan2, + PrimDef::FunNpCopysign => builtin_fns::call_numpy_copysign, + PrimDef::FunNpFmax => builtin_fns::call_numpy_fmax, + PrimDef::FunNpFmin => builtin_fns::call_numpy_fmin, + PrimDef::FunNpLdExp => builtin_fns::call_numpy_ldexp, + PrimDef::FunNpHypot => builtin_fns::call_numpy_hypot, + PrimDef::FunNpNextAfter => builtin_fns::call_numpy_nextafter, + _ => unreachable!(), + }; + + Ok(Some(func(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) }, )))), loc: None, - })), - ]; + } + } - let ast_list: Vec>> = - (0..top_level_def_list.len()).map(|_| None).collect(); + fn create_method(&self, prim: PrimDef, method_ty: Type) -> (StrRef, Type, DefinitionId) { + (prim.simple_name().into(), method_ty, prim.id()) + } - izip!(top_level_def_list, ast_list).collect_vec() + fn new_type_or_ndarray_ty(&mut self, scalar_ty: Type) -> (Type, u32) { + let ndarray = make_ndarray_ty(self.unifier, self.primitives, Some(scalar_ty), None); + + self.unifier.get_fresh_var_with_range(&[scalar_ty, ndarray], Some("T".into()), None) + } } diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index f0d4fd3..08db476 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -3,41 +3,259 @@ use std::convert::TryInto; use crate::symbol_resolver::SymbolValue; use crate::typecheck::typedef::{Mapping, VarMap}; use nac3parser::ast::{Constant, Location}; -use strum_macros::EnumIter; use strum::IntoEnumIterator; +use strum_macros::EnumIter; use super::*; /// All primitive types and functions in nac3core. -#[derive(Clone, Copy, Debug, EnumIter)] -pub enum PrimitiveDefinition { - Int32, // 0 - Int64, // 1 - Float, // 2 - Bool, // 3 - None, // 4 - Range, // 5 - Str, // 6 - Exception, // 7 - UInt32, // 8 - UInt64, // 9 - Option, // 10 - OptionIsSome, // 11 - OptionIsNone, // 12 - OptionUnwrap, // 13 - NDArray, // 14 - NDArrayCopy, // 15 - NDArrayFill, // 16 +#[derive(Clone, Copy, Debug, EnumIter, PartialEq, Eq)] +pub enum PrimDef { + Int32, + Int64, + Float, + Bool, + None, + Range, + Str, + Exception, + UInt32, + UInt64, + Option, + OptionIsSome, + OptionIsNone, + OptionUnwrap, + NDArray, + NDArrayCopy, + NDArrayFill, + FunInt32, + FunInt64, + FunUInt32, + FunUInt64, + FunFloat, + FunNpNDArray, + FunNpEmpty, + FunNpZeros, + FunNpOnes, + FunNpFull, + FunNpEye, + FunNpIdentity, + FunRound, + FunRound64, + FunNpRound, + FunRange, + FunStr, + FunBool, + FunFloor, + FunFloor64, + FunNpFloor, + FunCeil, + FunCeil64, + FunNpCeil, + FunLen, + FunMin, + FunNpMin, + FunNpMinimum, + FunMax, + FunNpMax, + FunNpMaximum, + FunAbs, + FunNpIsNan, + FunNpIsInf, + FunNpSin, + FunNpCos, + FunNpExp, + FunNpExp2, + FunNpLog, + FunNpLog10, + FunNpLog2, + FunNpFabs, + FunNpSqrt, + FunNpRint, + FunNpTan, + FunNpArcsin, + FunNpArccos, + FunNpArctan, + FunNpSinh, + FunNpCosh, + FunNpTanh, + FunNpArcsinh, + FunNpArccosh, + FunNpArctanh, + FunNpExpm1, + FunNpCbrt, + FunSpSpecErf, + FunSpSpecErfc, + FunSpSpecGamma, + FunSpSpecGammaln, + FunSpSpecJ0, + FunSpSpecJ1, + FunNpArctan2, + FunNpCopysign, + FunNpFmax, + FunNpFmin, + FunNpLdExp, + FunNpHypot, + FunNpNextAfter, + FunSome, } -impl PrimitiveDefinition { - pub fn id(self) -> DefinitionId { - return DefinitionId(self as usize); +/// Associated details of a [`PrimDef`] +struct PrimDefDetails { + name: &'static str, + simple_name: &'static str, +} + +impl PrimDef { + /// Get the assigned [`DefinitionId`] of this [`PrimDef`]. + /// + /// The assigned definition ID is defined by the position this [`PrimDef`] enum unit variant is defined at, + /// with the first `PrimDef`'s definition id being `0`. + pub fn id(&self) -> DefinitionId { + return DefinitionId(*self as usize); } + /// Check if a definition ID is that of a [`PrimDef`]. pub fn contains_id(id: DefinitionId) -> bool { Self::iter().any(|prim| prim.id() == id) } + + /// Get the definition "simple_name" of this [`PrimDef`]. + /// + /// If the [`PrimDef`] is a function, this corresponds to [`TopLevelDef::Function::simple_name`]. + /// + /// If the [`PrimDef`] is a class, this is equal to [`PrimDef::name`]. + pub fn simple_name<'a>(&'a self) -> &'static str { + self.details().simple_name + } + + /// Get the definition "name" of this [`PrimDef`]. + /// + /// If the [`PrimDef`] is a function, this corresponds to [`TopLevelDef::Function::name`]. + /// + /// If the [`PrimDef`] is a class, this corresponds to [`TopLevelDef::Class::name`]. + pub fn name<'a>(&'a self) -> &'static str { + self.details().name + } + + /// Get the associated details of this [`PrimDef`] + fn details(&self) -> PrimDefDetails { + use PrimDef::*; + + fn new(name: &'static str) -> PrimDefDetails { + PrimDefDetails { name, simple_name: name } + } + + fn new2(simple_name: &'static str, name: &'static str) -> PrimDefDetails { + PrimDefDetails { name, simple_name } + } + + match self { + Int32 => new("int32"), + Int64 => new("int64"), + Float => new("float"), + Bool => new("bool"), + None => new("none"), + Range => new("range"), + Str => new("str"), + Exception => new("Exception"), + UInt32 => new("uint32"), + UInt64 => new("uint64"), + Option => new("Option"), + OptionIsSome => new2("is_some", "Option.is_some"), + OptionIsNone => new2("is_none", "Option.is_none"), + OptionUnwrap => new2("unwrap", "Option.unwrap"), + NDArray => new("ndarray"), + NDArrayCopy => new2("copy", "ndarray.copy"), + NDArrayFill => new2("fill", "ndarray.fill"), + FunInt32 => new("int32"), + FunInt64 => new("int64"), + FunUInt32 => new("uint32"), + FunUInt64 => new("uint64"), + FunFloat => new("float"), + FunNpNDArray => new("np_ndarray"), + FunNpEmpty => new("np_empty"), + FunNpZeros => new("np_zeros"), + FunNpOnes => new("np_ones"), + FunNpFull => new("np_full"), + FunNpEye => new("np_eye"), + FunNpIdentity => new("np_identity"), + FunRound => new("round"), + FunRound64 => new("round64"), + FunNpRound => new("np_round"), + FunRange => new("range"), + FunStr => new("str"), + FunBool => new("bool"), + FunFloor => new("floor"), + FunFloor64 => new("floor64"), + FunNpFloor => new("np_floor"), + FunCeil => new("ceil"), + FunCeil64 => new("ceil64"), + FunNpCeil => new("np_ceil"), + FunLen => new("len"), + FunMin => new("min"), + FunNpMin => new("np_min"), + FunNpMinimum => new("np_minimum"), + FunMax => new("max"), + FunNpMax => new("np_max"), + FunNpMaximum => new("np_maximum"), + FunAbs => new("abs"), + FunNpIsNan => new("np_isnan"), + FunNpIsInf => new("np_isinf"), + FunNpSin => new("np_sin"), + FunNpCos => new("np_cos"), + FunNpExp => new("np_exp"), + FunNpExp2 => new("np_exp2"), + FunNpLog => new("np_log"), + FunNpLog10 => new("np_log10"), + FunNpLog2 => new("np_log2"), + FunNpFabs => new("np_fabs"), + FunNpSqrt => new("np_sqrt"), + FunNpRint => new("np_rint"), + FunNpTan => new("np_tan"), + FunNpArcsin => new("np_arcsin"), + FunNpArccos => new("np_arccos"), + FunNpArctan => new("np_arctan"), + FunNpSinh => new("np_sinh"), + FunNpCosh => new("np_cosh"), + FunNpTanh => new("np_tanh"), + FunNpArcsinh => new("np_arcsinh"), + FunNpArccosh => new("np_arccosh"), + FunNpArctanh => new("np_arctanh"), + FunNpExpm1 => new("np_expm1"), + FunNpCbrt => new("np_cbrt"), + FunSpSpecErf => new("sp_spec_erf"), + FunSpSpecErfc => new("sp_spec_erfc"), + FunSpSpecGamma => new("sp_spec_gamma"), + FunSpSpecGammaln => new("sp_spec_gammaln"), + FunSpSpecJ0 => new("sp_spec_j0"), + FunSpSpecJ1 => new("sp_spec_j1"), + FunNpArctan2 => new("np_arctan2"), + FunNpCopysign => new("np_copysign"), + FunNpFmax => new("np_fmax"), + FunNpFmin => new("np_fmin"), + FunNpLdExp => new("np_ldexp"), + FunNpHypot => new("np_hypot"), + FunNpNextAfter => new("np_nextafter"), + FunSome => new("Some"), + } + } +} + +/// Asserts that a [`PrimDef`] is in an allowlist. +/// +/// Like `debug_assert!`, this statements of this function are only +/// enabled if `cfg!(debug_assertions)` is true. +pub fn debug_assert_prim_is_allowed(prim: PrimDef, allowlist: [PrimDef; N]) { + if cfg!(debug_assertions) { + let allowed = allowlist.iter().any(|p| *p == prim); + if !allowed { + panic!( + "Disallowed primitive definition. Got {:?}, but expects it to be in {:?}", + prim, allowlist + ) + } + } } impl TopLevelDef { @@ -82,42 +300,42 @@ impl TopLevelComposer { pub fn make_primitives(size_t: u32) -> (PrimitiveStore, Unifier) { let mut unifier = Unifier::new(); let int32 = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimitiveDefinition::Int32.id(), + obj_id: PrimDef::Int32.id(), fields: HashMap::new(), params: VarMap::new(), }); let int64 = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimitiveDefinition::Int64.id(), + obj_id: PrimDef::Int64.id(), fields: HashMap::new(), params: VarMap::new(), }); let float = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimitiveDefinition::Float.id(), + obj_id: PrimDef::Float.id(), fields: HashMap::new(), params: VarMap::new(), }); let bool = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimitiveDefinition::Bool.id(), + obj_id: PrimDef::Bool.id(), fields: HashMap::new(), params: VarMap::new(), }); let none = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimitiveDefinition::None.id(), + obj_id: PrimDef::None.id(), fields: HashMap::new(), params: VarMap::new(), }); let range = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimitiveDefinition::Range.id(), + obj_id: PrimDef::Range.id(), fields: HashMap::new(), params: VarMap::new(), }); let str = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimitiveDefinition::Str.id(), + obj_id: PrimDef::Str.id(), fields: HashMap::new(), params: VarMap::new(), }); let exception = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimitiveDefinition::Exception.id(), + obj_id: PrimDef::Exception.id(), fields: vec![ ("__name__".into(), (int32, true)), ("__file__".into(), (str, true)), @@ -134,12 +352,12 @@ impl TopLevelComposer { params: VarMap::new(), }); let uint32 = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimitiveDefinition::UInt32.id(), + obj_id: PrimDef::UInt32.id(), fields: HashMap::new(), params: VarMap::new(), }); let uint64 = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimitiveDefinition::UInt64.id(), + obj_id: PrimDef::UInt64.id(), fields: HashMap::new(), params: VarMap::new(), }); @@ -156,7 +374,7 @@ impl TopLevelComposer { vars: VarMap::from([(option_type_var.1, option_type_var.0)]), })); let option = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimitiveDefinition::Option.id(), + obj_id: PrimDef::Option.id(), fields: vec![ ("is_some".into(), (is_some_type_fun_ty, true)), ("is_none".into(), (is_some_type_fun_ty, true)), @@ -174,7 +392,8 @@ impl TopLevelComposer { }; let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None); - let ndarray_ndims_tvar = unifier.get_fresh_const_generic_var(size_t_ty, Some("ndarray_ndims".into()), None); + let ndarray_ndims_tvar = + unifier.get_fresh_const_generic_var(size_t_ty, Some("ndarray_ndims".into()), None); let ndarray_copy_fun_ret_ty = unifier.get_fresh_var(None, None); let ndarray_copy_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![], @@ -185,13 +404,11 @@ impl TopLevelComposer { ]), })); let ndarray_fill_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![ - FuncArg { - name: "value".into(), - ty: ndarray_dtype_tvar.0, - default_value: None, - }, - ], + args: vec![FuncArg { + name: "value".into(), + ty: ndarray_dtype_tvar.0, + default_value: None, + }], ret: none, vars: VarMap::from([ (ndarray_dtype_tvar.1, ndarray_dtype_tvar.0), @@ -199,7 +416,7 @@ impl TopLevelComposer { ]), })); let ndarray = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimitiveDefinition::NDArray.id(), + obj_id: PrimDef::NDArray.id(), fields: Mapping::from([ ("copy".into(), (ndarray_copy_fun_ty, true)), ("fill".into(), (ndarray_fill_fun_ty, true)), @@ -359,9 +576,7 @@ impl TopLevelComposer { if let TypeEnum::TVar { id, .. } = unifier.get_ty(var_ty).as_ref() { Ok(*id) } else { - Err(HashSet::from([ - "not type var".to_string(), - ])) + Err(HashSet::from(["not type var".to_string()])) } } @@ -378,25 +593,27 @@ impl TopLevelComposer { let ( TypeEnum::TFunc(FunSignature { args: this_args, ret: this_ret, .. }), TypeEnum::TFunc(FunSignature { args: other_args, ret: other_ret, .. }), - ) = (this, other) else { + ) = (this, other) + else { unreachable!("this function must be called with function type") }; // check args - let args_ok = this_args - .iter() - .map(|FuncArg { name, ty, .. }| (name, type_var_to_concrete_def.get(ty).unwrap())) - .zip(other_args.iter().map(|FuncArg { name, ty, .. }| { - (name, type_var_to_concrete_def.get(ty).unwrap()) - })) - .all(|(this, other)| { - if this.0 == &"self".into() && this.0 == other.0 { - true - } else { - this.0 == other.0 - && check_overload_type_annotation_compatible(this.1, other.1, unifier) - } - }); + let args_ok = + this_args + .iter() + .map(|FuncArg { name, ty, .. }| (name, type_var_to_concrete_def.get(ty).unwrap())) + .zip(other_args.iter().map(|FuncArg { name, ty, .. }| { + (name, type_var_to_concrete_def.get(ty).unwrap()) + })) + .all(|(this, other)| { + if this.0 == &"self".into() && this.0 == other.0 { + true + } else { + this.0 == other.0 + && check_overload_type_annotation_compatible(this.1, other.1, unifier) + } + }); // check rets let ret_ok = check_overload_type_annotation_compatible( @@ -439,12 +656,10 @@ impl TopLevelComposer { } } => { - return Err(HashSet::from([ - format!( - "redundant type annotation for class fields at {}", - s.location - ), - ])) + return Err(HashSet::from([format!( + "redundant type annotation for class fields at {}", + s.location + )])) } ast::StmtKind::Assign { targets, .. } => { for t in targets { @@ -568,93 +783,89 @@ pub fn parse_parameter_default_value( Constant::Tuple(tuple) => Ok(SymbolValue::Tuple( tuple.iter().map(|x| handle_constant(x, loc)).collect::, _>>()?, )), - Constant::None => Err(HashSet::from([ - format!( - "`None` is not supported, use `none` for option type instead ({loc})" - ), - ])), + Constant::None => Err(HashSet::from([format!( + "`None` is not supported, use `none` for option type instead ({loc})" + )])), _ => unimplemented!("this constant is not supported at {}", loc), } } match &default.node { ast::ExprKind::Constant { value, .. } => handle_constant(value, &default.location), - ast::ExprKind::Call { func, args, .. } if args.len() == 1 => { - match &func.node { - ast::ExprKind::Name { id, .. } if *id == "int64".into() => match &args[0].node { - ast::ExprKind::Constant { value: Constant::Int(v), .. } => { - let v: Result = (*v).try_into(); - match v { - Ok(v) => Ok(SymbolValue::I64(v)), - _ => Err(HashSet::from([ - format!("default param value out of range at {}", default.location) - ])), - } + ast::ExprKind::Call { func, args, .. } if args.len() == 1 => match &func.node { + ast::ExprKind::Name { id, .. } if *id == "int64".into() => match &args[0].node { + ast::ExprKind::Constant { value: Constant::Int(v), .. } => { + let v: Result = (*v).try_into(); + match v { + Ok(v) => Ok(SymbolValue::I64(v)), + _ => Err(HashSet::from([format!( + "default param value out of range at {}", + default.location + )])), } - _ => Err(HashSet::from([ - format!("only allow constant integer here at {}", default.location), - ])) } - ast::ExprKind::Name { id, .. } if *id == "uint32".into() => match &args[0].node { - ast::ExprKind::Constant { value: Constant::Int(v), .. } => { - let v: Result = (*v).try_into(); - match v { - Ok(v) => Ok(SymbolValue::U32(v)), - _ => Err(HashSet::from([ - format!("default param value out of range at {}", default.location), - ])), - } + _ => Err(HashSet::from([format!( + "only allow constant integer here at {}", + default.location + )])), + }, + ast::ExprKind::Name { id, .. } if *id == "uint32".into() => match &args[0].node { + ast::ExprKind::Constant { value: Constant::Int(v), .. } => { + let v: Result = (*v).try_into(); + match v { + Ok(v) => Ok(SymbolValue::U32(v)), + _ => Err(HashSet::from([format!( + "default param value out of range at {}", + default.location + )])), } - _ => Err(HashSet::from([ - format!("only allow constant integer here at {}", default.location), - ])) } - ast::ExprKind::Name { id, .. } if *id == "uint64".into() => match &args[0].node { - ast::ExprKind::Constant { value: Constant::Int(v), .. } => { - let v: Result = (*v).try_into(); - match v { - Ok(v) => Ok(SymbolValue::U64(v)), - _ => Err(HashSet::from([ - format!("default param value out of range at {}", default.location), - ])), - } + _ => Err(HashSet::from([format!( + "only allow constant integer here at {}", + default.location + )])), + }, + ast::ExprKind::Name { id, .. } if *id == "uint64".into() => match &args[0].node { + ast::ExprKind::Constant { value: Constant::Int(v), .. } => { + let v: Result = (*v).try_into(); + match v { + Ok(v) => Ok(SymbolValue::U64(v)), + _ => Err(HashSet::from([format!( + "default param value out of range at {}", + default.location + )])), } - _ => Err(HashSet::from([ - format!("only allow constant integer here at {}", default.location), - ])) } - ast::ExprKind::Name { id, .. } if *id == "Some".into() => Ok( - SymbolValue::OptionSome( - Box::new(parse_parameter_default_value(&args[0], resolver)?) - ) - ), - _ => Err(HashSet::from([ - format!("unsupported default parameter at {}", default.location), - ])), - } - } - ast::ExprKind::Tuple { elts, .. } => Ok(SymbolValue::Tuple(elts - .iter() - .map(|x| parse_parameter_default_value(x, resolver)) - .collect::, _>>()? + _ => Err(HashSet::from([format!( + "only allow constant integer here at {}", + default.location + )])), + }, + ast::ExprKind::Name { id, .. } if *id == "Some".into() => Ok(SymbolValue::OptionSome( + Box::new(parse_parameter_default_value(&args[0], resolver)?), + )), + _ => Err(HashSet::from([format!( + "unsupported default parameter at {}", + default.location + )])), + }, + ast::ExprKind::Tuple { elts, .. } => Ok(SymbolValue::Tuple( + elts.iter() + .map(|x| parse_parameter_default_value(x, resolver)) + .collect::, _>>()?, )), ast::ExprKind::Name { id, .. } if id == &"none".into() => Ok(SymbolValue::OptionNone), ast::ExprKind::Name { id, .. } => { - resolver.get_default_param_value(default).ok_or_else( - || HashSet::from([ - format!( - "`{}` cannot be used as a default parameter at {} \ + resolver.get_default_param_value(default).ok_or_else(|| { + HashSet::from([format!( + "`{}` cannot be used as a default parameter at {} \ (not primitive type, option or tuple / not defined?)", - id, - default.location - ), - ]) - ) + id, default.location + )]) + }) } - _ => Err(HashSet::from([ - format!( - "unsupported default parameter (not primitive type, option or tuple) at {}", - default.location - ), - ])) + _ => Err(HashSet::from([format!( + "unsupported default parameter (not primitive type, option or tuple) at {}", + default.location + )])), } } diff --git a/nac3core/src/toplevel/numpy.rs b/nac3core/src/toplevel/numpy.rs index d4f4da9..8ca4f1c 100644 --- a/nac3core/src/toplevel/numpy.rs +++ b/nac3core/src/toplevel/numpy.rs @@ -1,6 +1,6 @@ use itertools::Itertools; use crate::{ - toplevel::helper::PrimitiveDefinition, + toplevel::helper::PrimDef, typecheck::{ type_inferencer::PrimitiveStore, typedef::{Type, TypeEnum, Unifier, VarMap}, @@ -37,7 +37,7 @@ pub fn subst_ndarray_tvars( let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else { panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray)) }; - debug_assert_eq!(*obj_id, PrimitiveDefinition::NDArray.id()); + debug_assert_eq!(*obj_id, PrimDef::NDArray.id()); if dtype.is_none() && ndims.is_none() { return ndarray @@ -66,7 +66,7 @@ fn unpack_ndarray_tvars( let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else { panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray)) }; - debug_assert_eq!(*obj_id, PrimitiveDefinition::NDArray.id()); + debug_assert_eq!(*obj_id, PrimDef::NDArray.id()); debug_assert_eq!(params.len(), 2); params.iter() diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index 073f640..c78b487 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -1,5 +1,5 @@ use crate::symbol_resolver::SymbolValue; -use crate::toplevel::helper::PrimitiveDefinition; +use crate::toplevel::helper::PrimDef; use crate::typecheck::typedef::VarMap; use super::*; use nac3parser::ast::Constant; @@ -95,7 +95,7 @@ pub fn parse_ast_to_type_annotation_kinds( } else if id == &"str".into() { Ok(TypeAnnotation::Primitive(primitives.str)) } else if id == &"Exception".into() { - Ok(TypeAnnotation::CustomClass { id: PrimitiveDefinition::Exception.id(), params: Vec::default() }) + Ok(TypeAnnotation::CustomClass { id: PrimDef::Exception.id(), params: Vec::default() }) } else if let Ok(obj_id) = resolver.get_identifier_def(*id) { let type_vars = { let def_read = top_level_defs[obj_id.0].try_read(); diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index dbb6f49..721103c 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -1,6 +1,6 @@ use std::cmp::max; use crate::symbol_resolver::SymbolValue; -use crate::toplevel::helper::PrimitiveDefinition; +use crate::toplevel::helper::PrimDef; use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys}; use crate::typecheck::{ type_inferencer::*, @@ -354,8 +354,8 @@ pub fn typeof_ndarray_broadcast( left: Type, right: Type, ) -> Result { - let is_left_ndarray = left.obj_id(unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()); - let is_right_ndarray = right.obj_id(unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()); + let is_left_ndarray = left.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + let is_right_ndarray = right.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id()); assert!(is_left_ndarray || is_right_ndarray); @@ -425,8 +425,8 @@ pub fn typeof_binop( lhs: Type, rhs: Type, ) -> Result, String> { - let is_left_ndarray = lhs.obj_id(unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()); - let is_right_ndarray = rhs.obj_id(unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()); + let is_left_ndarray = lhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + let is_right_ndarray = rhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id()); Ok(Some(match op { Operator::Add @@ -524,16 +524,16 @@ pub fn typeof_unaryop( Ok(match *op { Unaryop::Not => { match operand_obj_id { - Some(v) if v == PrimitiveDefinition::NDArray.id() => Some(operand), + Some(v) if v == PrimDef::NDArray.id() => Some(operand), Some(_) => Some(primitives.bool), _ => None } } Unaryop::Invert => { - if operand_obj_id.is_some_and(|id| id == PrimitiveDefinition::Bool.id()) { + if operand_obj_id.is_some_and(|id| id == PrimDef::Bool.id()) { Some(primitives.int32) - } else if operand_obj_id.is_some_and(|id| PrimitiveDefinition::iter().any(|prim| id == prim.id())) { + } else if operand_obj_id.is_some_and(|id| PrimDef::iter().any(|prim| id == prim.id())) { Some(operand) } else { None @@ -542,9 +542,9 @@ pub fn typeof_unaryop( Unaryop::UAdd | Unaryop::USub => { - if operand_obj_id.is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) { + if operand_obj_id.is_some_and(|id| id == PrimDef::NDArray.id()) { let (dtype, _) = unpack_ndarray_var_tys(unifier, operand); - if dtype.obj_id(unifier).is_some_and(|id| id == PrimitiveDefinition::Bool.id()) { + if dtype.obj_id(unifier).is_some_and(|id| id == PrimDef::Bool.id()) { return Err(if *op == Unaryop::UAdd { "The ufunc 'positive' cannot be applied to ndarray[bool, N]".to_string() } else { @@ -553,9 +553,9 @@ pub fn typeof_unaryop( } Some(operand) - } else if operand_obj_id.is_some_and(|id| id == PrimitiveDefinition::Bool.id()) { + } else if operand_obj_id.is_some_and(|id| id == PrimDef::Bool.id()) { Some(primitives.int32) - } else if operand_obj_id.is_some_and(|id| PrimitiveDefinition::iter().any(|prim| id == prim.id())) { + } else if operand_obj_id.is_some_and(|id| PrimDef::iter().any(|prim| id == prim.id())) { Some(operand) } else { None @@ -574,10 +574,10 @@ pub fn typeof_cmpop( ) -> Result, String> { let is_left_ndarray = lhs .obj_id(unifier) - .is_some_and(|id| id == PrimitiveDefinition::NDArray.id()); + .is_some_and(|id| id == PrimDef::NDArray.id()); let is_right_ndarray = rhs .obj_id(unifier) - .is_some_and(|id| id == PrimitiveDefinition::NDArray.id()); + .is_some_and(|id| id == PrimDef::NDArray.id()); Ok(Some(if is_left_ndarray || is_right_ndarray { let brd = typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?; diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 9007b38..3993bb5 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -9,7 +9,7 @@ use super::{magic_methods::*, type_error::TypeError, typedef::CallId}; use crate::{ symbol_resolver::{SymbolResolver, SymbolValue}, toplevel::{ - helper::PrimitiveDefinition, + helper::PrimDef, numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, TopLevelContext, }, @@ -235,7 +235,7 @@ impl<'a> Fold<()> for Inferencer<'a> { } else { let list_like_ty = match &*self.unifier.get_ty(iter.custom.unwrap()) { TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() }), - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimitiveDefinition::NDArray.id() => todo!(), + TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => todo!(), _ => unreachable!(), }; self.unify(list_like_ty, iter.custom.unwrap(), &iter.location)?; @@ -877,7 +877,7 @@ impl<'a> Inferencer<'a> { let arg0 = self.fold_expr(args.remove(0))?; let arg0_ty = arg0.custom.unwrap(); - let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) { + let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty); make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims)) @@ -919,7 +919,7 @@ impl<'a> Inferencer<'a> { let arg0 = self.fold_expr(args.remove(0))?; let arg0_ty = arg0.custom.unwrap(); - let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) { + let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { let (ndarray_dtype, _) = unpack_ndarray_var_tys(self.unifier, arg0_ty); ndarray_dtype @@ -970,13 +970,13 @@ impl<'a> Inferencer<'a> { let arg1 = self.fold_expr(args.remove(0))?; let arg1_ty = arg1.custom.unwrap(); - let arg0_dtype = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) { + let arg0_dtype = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { unpack_ndarray_var_tys(self.unifier, arg0_ty).0 } else { arg0_ty }; - let arg1_dtype = if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) { + let arg1_dtype = if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { unpack_ndarray_var_tys(self.unifier, arg1_ty).0 } else { arg1_ty @@ -1007,11 +1007,11 @@ impl<'a> Inferencer<'a> { let ret = if [ &arg0_ty, &arg1_ty, - ].into_iter().any(|arg_ty| arg_ty.obj_id(self.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id())) { + ].into_iter().any(|arg_ty| arg_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())) { // typeof_ndarray_broadcast requires both dtypes to be the same, but ldexp accepts // (float, int32), so convert it to align with the dtype of the first arg let arg1_ty = if id == &"np_ldexp".into() { - if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) { + if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { let (_, ndims) = unpack_ndarray_var_tys(self.unifier, arg1_ty); make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndims)) @@ -1108,7 +1108,7 @@ impl<'a> Inferencer<'a> { let arg0 = self.fold_expr(args.remove(0))?; let arg0_ty = arg0.custom.unwrap(); - let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) { + let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty); make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims)) @@ -1500,7 +1500,7 @@ impl<'a> Inferencer<'a> { ops: &[ast::Cmpop], comparators: &[ast::Expr>], ) -> InferenceResult { - if ops.len() > 1 && once(left).chain(comparators).any(|expr| expr.custom.unwrap().obj_id(self.unifier).is_some_and(|id| id == PrimitiveDefinition::NDArray.id())) { + if ops.len() > 1 && once(left).chain(comparators).any(|expr| expr.custom.unwrap().obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())) { return Err(HashSet::from([String::from("Comparator chaining with ndarray types not supported")])) } @@ -1614,7 +1614,7 @@ impl<'a> Inferencer<'a> { } let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) { TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }), - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimitiveDefinition::NDArray.id() => { + TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims)) @@ -1627,7 +1627,7 @@ impl<'a> Inferencer<'a> { } ExprKind::Constant { value: ast::Constant::Int(val), .. } => { match &*self.unifier.get_ty(value.custom.unwrap()) { - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimitiveDefinition::NDArray.id() => { + TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); self.infer_subscript_ndarray(value, ty, ndims) } @@ -1650,7 +1650,7 @@ impl<'a> Inferencer<'a> { if value.custom .unwrap() .obj_id(self.unifier) - .is_some_and(|id| id == PrimitiveDefinition::NDArray.id()) + .is_some_and(|id| id == PrimDef::NDArray.id()) .not() { return report_error("Tuple slices are only supported for ndarrays", slice.location) } @@ -1683,7 +1683,7 @@ impl<'a> Inferencer<'a> { self.constrain(value.custom.unwrap(), list, &value.location)?; Ok(ty) } - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimitiveDefinition::NDArray.id() => { + TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); let valid_index_tys = [ diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index 0062a54..7203b1c 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -3,7 +3,7 @@ use super::*; use crate::{ codegen::CodeGenContext, symbol_resolver::ValueEnum, - toplevel::{DefinitionId, helper::PrimitiveDefinition, TopLevelDef}, + toplevel::{DefinitionId, helper::PrimDef, TopLevelDef}, }; use indoc::indoc; use std::iter::zip; @@ -73,7 +73,7 @@ impl TestEnvironment { let mut unifier = Unifier::new(); let int32 = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimitiveDefinition::Int32.id(), + obj_id: PrimDef::Int32.id(), fields: HashMap::new(), params: VarMap::new(), }); @@ -86,59 +86,59 @@ impl TestEnvironment { fields.insert("__add__".into(), (add_ty, false)); }); let int64 = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimitiveDefinition::Int64.id(), + obj_id: PrimDef::Int64.id(), fields: HashMap::new(), params: VarMap::new(), }); let float = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimitiveDefinition::Float.id(), + obj_id: PrimDef::Float.id(), fields: HashMap::new(), params: VarMap::new(), }); let bool = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimitiveDefinition::Bool.id(), + obj_id: PrimDef::Bool.id(), fields: HashMap::new(), params: VarMap::new(), }); let none = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimitiveDefinition::None.id(), + obj_id: PrimDef::None.id(), fields: HashMap::new(), params: VarMap::new(), }); let range = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimitiveDefinition::Range.id(), + obj_id: PrimDef::Range.id(), fields: HashMap::new(), params: VarMap::new(), }); let str = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimitiveDefinition::Str.id(), + obj_id: PrimDef::Str.id(), fields: HashMap::new(), params: VarMap::new(), }); let exception = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimitiveDefinition::Exception.id(), + obj_id: PrimDef::Exception.id(), fields: HashMap::new(), params: VarMap::new(), }); let uint32 = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimitiveDefinition::UInt32.id(), + obj_id: PrimDef::UInt32.id(), fields: HashMap::new(), params: VarMap::new(), }); let uint64 = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimitiveDefinition::UInt64.id(), + obj_id: PrimDef::UInt64.id(), fields: HashMap::new(), params: VarMap::new(), }); let option = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimitiveDefinition::Option.id(), + obj_id: PrimDef::Option.id(), fields: HashMap::new(), params: VarMap::new(), }); let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None); let ndarray_ndims_tvar = unifier.get_fresh_const_generic_var(uint64, Some("ndarray_ndims".into()), None); let ndarray = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimitiveDefinition::NDArray.id(), + obj_id: PrimDef::NDArray.id(), fields: HashMap::new(), params: VarMap::from([ (ndarray_dtype_tvar.1, ndarray_dtype_tvar.0), @@ -211,7 +211,7 @@ impl TestEnvironment { let mut identifier_mapping = HashMap::new(); let mut top_level_defs: Vec>> = Vec::new(); let int32 = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimitiveDefinition::Int32.id(), + obj_id: PrimDef::Int32.id(), fields: HashMap::new(), params: VarMap::new(), }); @@ -224,57 +224,57 @@ impl TestEnvironment { fields.insert("__add__".into(), (add_ty, false)); }); let int64 = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimitiveDefinition::Int64.id(), + obj_id: PrimDef::Int64.id(), fields: HashMap::new(), params: VarMap::new(), }); let float = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimitiveDefinition::Float.id(), + obj_id: PrimDef::Float.id(), fields: HashMap::new(), params: VarMap::new(), }); let bool = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimitiveDefinition::Bool.id(), + obj_id: PrimDef::Bool.id(), fields: HashMap::new(), params: VarMap::new(), }); let none = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimitiveDefinition::None.id(), + obj_id: PrimDef::None.id(), fields: HashMap::new(), params: VarMap::new(), }); let range = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimitiveDefinition::Range.id(), + obj_id: PrimDef::Range.id(), fields: HashMap::new(), params: VarMap::new(), }); let str = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimitiveDefinition::Str.id(), + obj_id: PrimDef::Str.id(), fields: HashMap::new(), params: VarMap::new(), }); let exception = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimitiveDefinition::Exception.id(), + obj_id: PrimDef::Exception.id(), fields: HashMap::new(), params: VarMap::new(), }); let uint32 = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimitiveDefinition::UInt32.id(), + obj_id: PrimDef::UInt32.id(), fields: HashMap::new(), params: VarMap::new(), }); let uint64 = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimitiveDefinition::UInt64.id(), + obj_id: PrimDef::UInt64.id(), fields: HashMap::new(), params: VarMap::new(), }); let option = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimitiveDefinition::Option.id(), + obj_id: PrimDef::Option.id(), fields: HashMap::new(), params: VarMap::new(), }); let ndarray = unifier.add_ty(TypeEnum::TObj { - obj_id: PrimitiveDefinition::NDArray.id(), + obj_id: PrimDef::NDArray.id(), fields: HashMap::new(), params: VarMap::new(), });