From 727a1886b3ba684823f7d2936b62fe6afe783a83 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 27 Mar 2024 17:06:58 +0800 Subject: [PATCH] core: Implement elementwise unary operators --- nac3core/src/codegen/expr.rs | 29 +++++++- nac3core/src/codegen/numpy.rs | 67 +++++++++++++++++++ ...el__test__test_analyze__generic_class.snap | 2 +- ...t__test_analyze__inheritance_override.snap | 2 +- ...est__test_analyze__list_tuple_generic.snap | 4 +- ...__toplevel__test__test_analyze__self1.snap | 2 +- ...t__test_analyze__simple_class_compose.snap | 4 +- nac3core/src/typecheck/magic_methods.rs | 19 ++++++ nac3core/src/typecheck/type_inferencer/mod.rs | 15 ++++- nac3standalone/demo/src/ndarray.py | 50 ++++++++++++++ 10 files changed, 184 insertions(+), 10 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 3f8bafe..2c05d66 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1296,7 +1296,7 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>( /// Generates LLVM IR for a unary operator expression using the [`Type`] and /// [LLVM value][`BasicValueEnum`] of the operands. pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( - _generator: &mut G, + generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, op: &ast::Unaryop, operand: (&Option, BasicValueEnum<'ctx>), @@ -1336,6 +1336,33 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( .unwrap(), _ => val.into(), } + } else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) { + let llvm_usize = generator.get_size_type(ctx.ctx); + let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty); + + let val = NDArrayValue::from_ptr_val( + val.into_pointer_value(), + llvm_usize, + None, + ); + + let res = numpy::ndarray_elementwise_unaryop_impl( + generator, + ctx, + ndarray_dtype, + None, + val, + |generator, ctx, val| { + gen_unaryop_expr_with_values( + generator, + ctx, + op, + (&Some(ndarray_dtype), val) + )?.unwrap().to_basic_value_enum(ctx, generator, ndarray_dtype) + }, + )?; + + res.as_ptr_value().into() } else { unimplemented!() })) diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 454f238..82c86a5 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -340,6 +340,31 @@ fn ndarray_fill_indexed<'ctx, G, ValueFn>( ) } +fn ndarray_fill_mapping<'ctx, G, MapFn>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + src: NDArrayValue<'ctx>, + dest: NDArrayValue<'ctx>, + map_fn: MapFn, +) -> Result<(), String> + where + G: CodeGenerator + ?Sized, + MapFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, BasicValueEnum<'ctx>) -> Result, String>, +{ + ndarray_fill_flattened( + generator, + ctx, + dest, + |generator, ctx, i| { + let elem = unsafe { + src.data().get_unchecked(ctx, generator, i, None) + }; + + map_fn(generator, ctx, elem) + }, + ) +} + /// Generates the LLVM IR for checking whether the source `ndarray` can be broadcast to the shape of /// the target `ndarray`. fn ndarray_assert_is_broadcastable<'ctx, G: CodeGenerator + ?Sized>( @@ -656,6 +681,48 @@ fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>( Ok(ndarray) } +pub fn ndarray_elementwise_unaryop_impl<'ctx, G, MapFn>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + elem_ty: Type, + res: Option>, + operand: NDArrayValue<'ctx>, + map_fn: MapFn, +) -> Result, String> + where + G: CodeGenerator, + MapFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, BasicValueEnum<'ctx>) -> Result, String>, +{ + let res = res.unwrap_or_else(|| { + create_ndarray_dyn_shape( + generator, + ctx, + elem_ty, + &operand, + |_, ctx, v| { + Ok(v.load_ndims(ctx)) + }, + |generator, ctx, v, idx| { + unsafe { + Ok(v.dim_sizes().get_typed_unchecked(ctx, generator, idx, None)) + } + }, + ).unwrap() + }); + + ndarray_fill_mapping( + generator, + ctx, + operand, + res, + |generator, ctx, elem| { + map_fn(generator, ctx, elem) + } + )?; + + Ok(res) +} + /// LLVM-typed implementation for computing elementwise binary operations on two input operands. /// /// If the operand is a `ndarray`, the broadcast index corresponding to each element in the output diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index 498e3f1..a9dc4ad 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -5,7 +5,7 @@ expression: res_vec [ "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [124]\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [127]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index 8454bfb..dc36b54 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -7,7 +7,7 @@ expression: res_vec "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B[typevar113]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar113\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B[typevar116]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar116\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index ee506c1..d6adcee 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -5,8 +5,8 @@ expression: res_vec [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [126]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [131]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [129]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [134]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index 16159e4..55767a8 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar112, typevar113]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar112\", \"typevar113\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar115, typevar116]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar115\", \"typevar116\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index dfda3a8..f4f96f2 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -6,12 +6,12 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [132]\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [135]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [140]\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [143]\n}\n", ] diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index bfd137d..38d480c 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -466,6 +466,23 @@ pub fn typeof_binop( })) } +pub fn typeof_unaryop( + unifier: &mut Unifier, + primitives: &PrimitiveStore, + op: &Unaryop, + operand: Type, +) -> Result, String> { + if *op == Unaryop::Not && operand.obj_id(unifier).is_some_and(|id| id == primitives.ndarray.obj_id(unifier).unwrap()) { + return Err("The truth value of an array with more than one element is ambiguous".to_string()) + } + + Ok(if operand.obj_id(unifier).is_some_and(|id| PRIMITIVE_DEF_IDS.iter().any(|prim_id| id == prim_id)) { + Some(operand) + } else { + None + }) +} + pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) { let PrimitiveStore { int32: int32_t, @@ -525,4 +542,6 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None); impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); + impl_sign(unifier, store, ndarray_t, Some(ndarray_t)); + impl_invert(unifier, store, ndarray_t, Some(ndarray_t)); } diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 1b28a24..c19d3d1 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -549,7 +549,9 @@ impl<'a> Fold<()> for Inferencer<'a> { ExprKind::BinOp { left, op, right } => { Some(self.infer_bin_ops(expr.location, left, op, right, false)?) } - ExprKind::UnaryOp { op, operand } => Some(self.infer_unary_ops(op, operand)?), + ExprKind::UnaryOp { op, operand } => { + Some(self.infer_unary_ops(expr.location, op, operand)?) + } ExprKind::Compare { left, ops, comparators } => { Some(self.infer_compare(left, ops, comparators)?) } @@ -1247,11 +1249,20 @@ impl<'a> Inferencer<'a> { fn infer_unary_ops( &mut self, + location: Location, op: &ast::Unaryop, operand: &ast::Expr>, ) -> InferenceResult { let method = unaryop_name(op).into(); - self.build_method_call(operand.location, method, operand.custom.unwrap(), vec![], None) + + let ret = typeof_unaryop( + self.unifier, + self.primitives, + op, + operand.custom.unwrap(), + ).map_err(|e| HashSet::from([format!("{e} (at {location})")]))?; + + self.build_method_call(operand.location, method, operand.custom.unwrap(), vec![], ret) } fn infer_compare( diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index 87f0b36..4543097 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -1,3 +1,7 @@ +@extern +def output_bool(x: bool): + ... + @extern def output_int32(x: int32): ... @@ -6,10 +10,20 @@ def output_int32(x: int32): def output_float64(x: float): ... +def output_ndarray_bool_2(n: ndarray[bool, Literal[2]]): + for r in range(len(n)): + for c in range(len(n[r])): + output_bool(n[r][c]) + def output_ndarray_int32_1(n: ndarray[int32, Literal[1]]): for i in range(len(n)): output_int32(n[i]) +def output_ndarray_int32_2(n: ndarray[int32, Literal[2]]): + for r in range(len(n)): + for c in range(len(n[r])): + output_int32(n[r][c]) + def output_ndarray_float_1(n: ndarray[float, Literal[1]]): for i in range(len(n)): output_float64(n[i]) @@ -408,6 +422,39 @@ def test_ndarray_ipow_broadcast_scalar(): output_ndarray_float_2(x) +def test_ndarray_pos(): + x_int32 = np_full([2, 2], -2) + y_int32 = +x_int32 + + output_ndarray_int32_2(x_int32) + output_ndarray_int32_2(y_int32) + + x_float = np_full([2, 2], -2.0) + y_float = +x_float + + output_ndarray_float_2(x_float) + output_ndarray_float_2(y_float) + +def test_ndarray_neg(): + x_int32 = np_full([2, 2], -2) + y_int32 = -x_int32 + + output_ndarray_int32_2(x_int32) + output_ndarray_int32_2(y_int32) + + x_float = np_full([2, 2], 2.0) + y_float = -x_float + + output_ndarray_float_2(x_float) + output_ndarray_float_2(y_float) + +def test_ndarray_inv(): + x_int32 = np_full([2, 2], -2) + y_int32 = ~x_int32 + + output_ndarray_int32_2(x_int32) + output_ndarray_int32_2(y_int32) + def run() -> int32: test_ndarray_ctor() test_ndarray_empty() @@ -467,5 +514,8 @@ def run() -> int32: test_ndarray_ipow() test_ndarray_ipow_broadcast() test_ndarray_ipow_broadcast_scalar() + test_ndarray_pos() + test_ndarray_neg() + test_ndarray_inv() return 0