diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 4ff40c0..e0067bb 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -1,11 +1,12 @@ -use inkwell::{FloatPredicate, IntPredicate}; +use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel}; use inkwell::types::BasicTypeEnum; use inkwell::values::BasicValueEnum; use itertools::Itertools; use crate::codegen::{CodeGenContext, CodeGenerator, extern_fns, irrt, llvm_intrinsics, numpy}; -use crate::codegen::classes::NDArrayValue; +use crate::codegen::classes::{NDArrayValue, UntypedArrayLikeAccessor}; use crate::codegen::numpy::ndarray_elementwise_unaryop_impl; +use crate::codegen::stmt::gen_for_callback_incrementing; use crate::toplevel::helper::PRIMITIVE_DEF_IDS; use crate::toplevel::numpy::unpack_ndarray_var_tys; use crate::typecheck::typedef::Type; @@ -705,6 +706,92 @@ pub fn call_min<'ctx>( } } +/// Invokes the `np_min` builtin function. +pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + a: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_min"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + + let (a_ty, a) = a; + + Ok(match a { + BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => { + debug_assert!([ + ctx.primitives.bool, + ctx.primitives.int32, + ctx.primitives.uint32, + ctx.primitives.int64, + ctx.primitives.uint64, + ctx.primitives.float, + ].iter().any(|ty| ctx.unifier.unioned(a_ty, *ty))); + + a + } + + BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); + let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); + + let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); + let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes()); + if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { + let n_sz_eqz = ctx.builder + .build_int_compare( + IntPredicate::NE, + n_sz, + n_sz.get_type().const_zero(), + "", + ) + .unwrap(); + + ctx.make_assert( + generator, + n_sz_eqz, + "0:ValueError", + "zero-size array to reduction operation minimum which has no identity", + [None, None, None], + ctx.current_loc, + ); + } + + let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?; + unsafe { + let identity = n.data() + .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None); + ctx.builder.build_store(accumulator_addr, identity).unwrap(); + } + + gen_for_callback_incrementing( + generator, + ctx, + llvm_usize.const_int(1, false), + (n_sz, false), + |generator, ctx, idx| { + let elem = unsafe { + n.data().get_unchecked(ctx, generator, &idx, None) + }; + + let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); + let result = call_min(ctx, (elem_ty, accumulator), (elem_ty, elem)); + ctx.builder.build_store(accumulator_addr, result).unwrap(); + + Ok(()) + }, + llvm_usize.const_int(1, false), + )?; + + let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); + accumulator + } + + _ => unsupported_type(ctx, FN_NAME, &[a_ty]) + }) +} + /// Invokes the `max` builtin function. pub fn call_max<'ctx>( ctx: &mut CodeGenContext<'ctx, '_>, @@ -752,6 +839,92 @@ pub fn call_max<'ctx>( } } +/// Invokes the `np_max` builtin function. +pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + a: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_max"; + + let llvm_usize = generator.get_size_type(ctx.ctx); + + let (a_ty, a) = a; + + Ok(match a { + BasicValueEnum::IntValue(_) | BasicValueEnum::FloatValue(_) => { + debug_assert!([ + ctx.primitives.bool, + ctx.primitives.int32, + ctx.primitives.uint32, + ctx.primitives.int64, + ctx.primitives.uint64, + ctx.primitives.float, + ].iter().any(|ty| ctx.unifier.unioned(a_ty, *ty))); + + a + } + + BasicValueEnum::PointerValue(n) if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => { + let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); + let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); + + let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); + let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes()); + if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { + let n_sz_eqz = ctx.builder + .build_int_compare( + IntPredicate::NE, + n_sz, + n_sz.get_type().const_zero(), + "", + ) + .unwrap(); + + ctx.make_assert( + generator, + n_sz_eqz, + "0:ValueError", + "zero-size array to reduction operation minimum which has no identity", + [None, None, None], + ctx.current_loc, + ); + } + + let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?; + unsafe { + let identity = n.data() + .get_unchecked(ctx, generator, &llvm_usize.const_zero(), None); + ctx.builder.build_store(accumulator_addr, identity).unwrap(); + } + + gen_for_callback_incrementing( + generator, + ctx, + llvm_usize.const_int(1, false), + (n_sz, false), + |generator, ctx, idx| { + let elem = unsafe { + n.data().get_unchecked(ctx, generator, &idx, None) + }; + + let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); + let result = call_max(ctx, (elem_ty, accumulator), (elem_ty, elem)); + ctx.builder.build_store(accumulator_addr, result).unwrap(); + + Ok(()) + }, + llvm_usize.const_int(1, false), + )?; + + let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); + accumulator + } + + _ => unsupported_type(ctx, FN_NAME, &[a_ty]) + }) +} + /// Invokes the `abs` builtin function. pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index e057db9..7187ae8 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1,4 +1,14 @@ -use super::*; +use std::iter::once; + +use indexmap::IndexMap; +use inkwell::{ + attributes::{Attribute, AttributeLoc}, + IntPredicate, + types::{BasicMetadataTypeEnum, BasicType}, + values::{BasicMetadataValueEnum, BasicValue, CallSiteValue} +}; +use itertools::Either; + use crate::{ codegen::{ builtin_fns, @@ -15,13 +25,8 @@ use crate::{ }, typecheck::typedef::VarMap, }; -use inkwell::{ - attributes::{Attribute, AttributeLoc}, - types::{BasicType, BasicMetadataTypeEnum}, - values::{BasicValue, BasicMetadataValueEnum, CallSiteValue}, - IntPredicate -}; -use itertools::Either; + +use super::*; type BuiltinInfo = Vec<(Arc>, Option)>; @@ -1378,6 +1383,28 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built )))), loc: None, })), + { + let ret_ty = unifier.get_fresh_var(Some("R".into()), None); + let var_map = num_or_ndarray_var_map.clone() + .into_iter() + .chain(once((ret_ty.1, ret_ty.0))) + .collect::>(); + + create_fn_by_codegen( + unifier, + &var_map, + "np_min", + ret_ty.0, + &[(float_or_ndarray_ty.0, "a")], + Box::new(|ctx, _, fun, args, generator| { + let a_ty = fun.0.args[0].ty; + let a = args[0].1.clone() + .to_basic_value_enum(ctx, generator, a_ty)?; + + Ok(Some(builtin_fns::call_numpy_min(generator, ctx, (a_ty, a))?)) + }), + ) + }, Arc::new(RwLock::new(TopLevelDef::Function { name: "max".into(), simple_name: "max".into(), @@ -1405,6 +1432,28 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built )))), loc: None, })), + { + let ret_ty = unifier.get_fresh_var(Some("R".into()), None); + let var_map = num_or_ndarray_var_map.clone() + .into_iter() + .chain(once((ret_ty.1, ret_ty.0))) + .collect::>(); + + create_fn_by_codegen( + unifier, + &var_map, + "np_max", + ret_ty.0, + &[(float_or_ndarray_ty.0, "a")], + Box::new(|ctx, _, fun, args, generator| { + let a_ty = fun.0.args[0].ty; + let a = args[0].1.clone() + .to_basic_value_enum(ctx, generator, a_ty)?; + + Ok(Some(builtin_fns::call_numpy_max(generator, ctx, (a_ty, a))?)) + }), + ) + }, Arc::new(RwLock::new(TopLevelDef::Function { name: "abs".into(), simple_name: "abs".into(), diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index 027ad42..65ceb24 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -5,7 +5,7 @@ expression: res_vec [ "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [222]\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [224]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index 248aa92..2ee8e8f 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -7,7 +7,7 @@ expression: res_vec "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B[typevar211]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar211\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B[typevar213]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar213\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index c79adf8..40b9688 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -5,8 +5,8 @@ expression: res_vec [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [224]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [229]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [226]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [231]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index 9688485..7204c0d 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar210, typevar211]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar210\", \"typevar211\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar212, typevar213]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar212\", \"typevar213\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index a217b00..1a526c5 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -6,12 +6,12 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [230]\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [232]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [238]\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [240]\n}\n", ] diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 24006d8..61c56c6 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -908,6 +908,48 @@ impl<'a> Inferencer<'a> { })) } + if [ + "np_min", + "np_max", + ].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 { + let arg0 = self.fold_expr(args.remove(0))?; + let arg0_ty = arg0.custom.unwrap(); + + let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) { + let (ndarray_dtype, _) = unpack_ndarray_var_tys(self.unifier, arg0_ty); + + ndarray_dtype + } else { + arg0_ty + }; + + let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![ + FuncArg { + name: "a".into(), + ty: arg0.custom.unwrap(), + default_value: None, + }, + ], + ret, + vars: VarMap::new(), + })); + + return Ok(Some(Located { + location, + custom: Some(ret), + node: ExprKind::Call { + func: Box::new(Located { + custom: Some(custom), + location: func.location, + node: ExprKind::Name { id: *id, ctx: ctx.clone() }, + }), + args: vec![arg0], + keywords: vec![], + }, + })) + } + if [ "np_arctan2", "np_copysign", diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index f38e70b..b5fb153 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -174,6 +174,8 @@ def patch(module): # NumPy Math functions module.np_isnan = np.isnan module.np_isinf = np.isinf + module.np_min = np.min + module.np_max = np.max module.np_sin = np.sin module.np_cos = np.cos module.np_exp = np.exp diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index 15afd5d..66f6fe9 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -759,6 +759,20 @@ def test_ndarray_ceil(): output_ndarray_int64_2(xf64) output_ndarray_float_2(xff) +def test_ndarray_min(): + x = np_identity(2) + y = np_min(x) + + output_ndarray_float_2(x) + output_float64(y) + +def test_ndarray_max(): + x = np_identity(2) + y = np_max(x) + + output_ndarray_float_2(x) + output_float64(y) + def test_ndarray_abs(): x = np_identity(2) y = abs(x) @@ -1363,6 +1377,8 @@ def run() -> int32: test_ndarray_round() test_ndarray_floor() + test_ndarray_min() + test_ndarray_max() test_ndarray_abs() test_ndarray_isnan() test_ndarray_isinf()