diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 17460250..2f46cd1a 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -1905,15 +1905,23 @@ pub fn gen_ndarray_eye<'ctx>( )) }?; - call_ndarray_eye_impl( - generator, - context, - context.primitives.float, - nrows_arg.into_int_value(), - ncols_arg.into_int_value(), - offset_arg.into_int_value(), - ) - .map(NDArrayValue::into) + let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); + + let nrows = Int(Int32) + .check_value(generator, context.ctx, nrows_arg) + .unwrap() + .s_extend_or_bit_cast(generator, context, SizeT); + let ncols = Int(Int32) + .check_value(generator, context.ctx, ncols_arg) + .unwrap() + .s_extend_or_bit_cast(generator, context, SizeT); + let offset = Int(Int32) + .check_value(generator, context.ctx, offset_arg) + .unwrap() + .s_extend_or_bit_cast(generator, context, SizeT); + + let ndarray = NDArrayObject::make_np_eye(generator, context, dtype, nrows, ncols, offset); + Ok(ndarray.instance.value) } /// Generates LLVM IR for `ndarray.identity`. @@ -1927,20 +1935,15 @@ pub fn gen_ndarray_identity<'ctx>( assert!(obj.is_none()); assert_eq!(args.len(), 1); - let llvm_usize = generator.get_size_type(context.ctx); + let (dtype, _) = unpack_ndarray_var_tys(&mut context.unifier, fun.0.ret); let n_ty = fun.0.args[0].ty; let n_arg = args[0].1.clone().to_basic_value_enum(context, generator, n_ty)?; - call_ndarray_eye_impl( - generator, - context, - context.primitives.float, - n_arg.into_int_value(), - n_arg.into_int_value(), - llvm_usize.const_zero(), - ) - .map(NDArrayValue::into) + let n = Int(Int32).check_value(generator, context.ctx, n_arg).unwrap(); + let n = n.s_extend_or_bit_cast(generator, context, SizeT); + let ndarray = NDArrayObject::make_np_identity(generator, context, dtype, n); + Ok(ndarray.instance.value) } /// Generates LLVM IR for `ndarray.copy`. @@ -1954,20 +1957,14 @@ pub fn gen_ndarray_copy<'ctx>( assert!(obj.is_some()); assert!(args.is_empty()); - let llvm_usize = generator.get_size_type(context.ctx); - let this_ty = obj.as_ref().unwrap().0; - let (this_elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, this_ty); let this_arg = obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?; - ndarray_copy_impl( - generator, - context, - this_elem_ty, - NDArrayValue::from_ptr_val(this_arg.into_pointer_value(), llvm_usize, None), - ) - .map(NDArrayValue::into) + let this = AnyObject { value: this_arg, ty: this_ty }; + let this = NDArrayObject::from_object(generator, context, this); + let ndarray = this.make_copy(generator, context); + Ok(ndarray.instance.value) } /// Generates LLVM IR for `ndarray.fill`. @@ -1981,48 +1978,15 @@ pub fn gen_ndarray_fill<'ctx>( assert!(obj.is_some()); assert_eq!(args.len(), 1); - let llvm_usize = generator.get_size_type(context.ctx); - let this_ty = obj.as_ref().unwrap().0; - let this_arg = obj - .as_ref() - .unwrap() - .1 - .clone() - .to_basic_value_enum(context, generator, this_ty)? - .into_pointer_value(); + let this_arg = + obj.as_ref().unwrap().1.clone().to_basic_value_enum(context, generator, this_ty)?; let value_ty = fun.0.args[0].ty; let value_arg = args[0].1.clone().to_basic_value_enum(context, generator, value_ty)?; - ndarray_fill_flattened( - generator, - context, - NDArrayValue::from_ptr_val(this_arg, llvm_usize, None), - |generator, ctx, _| { - let value = if value_arg.is_pointer_value() { - let llvm_i1 = ctx.ctx.bool_type(); - - let copy = generator.gen_var_alloc(ctx, value_arg.get_type(), None)?; - - call_memcpy_generic( - ctx, - copy, - value_arg.into_pointer_value(), - value_arg.get_type().size_of().map(Into::into).unwrap(), - llvm_i1.const_zero(), - ); - - copy.into() - } else if value_arg.is_int_value() || value_arg.is_float_value() { - value_arg - } else { - codegen_unreachable!(ctx) - }; - - Ok(value) - }, - )?; - + let this = AnyObject { value: this_arg, ty: this_ty }; + let this = NDArrayObject::from_object(generator, context, this); + this.fill(generator, context, value_arg); Ok(()) } diff --git a/nac3core/src/codegen/object/ndarray/factory.rs b/nac3core/src/codegen/object/ndarray/factory.rs index 712c0cff..7f7e18d9 100644 --- a/nac3core/src/codegen/object/ndarray/factory.rs +++ b/nac3core/src/codegen/object/ndarray/factory.rs @@ -1,4 +1,4 @@ -use inkwell::values::BasicValueEnum; +use inkwell::{values::BasicValueEnum, IntPredicate}; use super::NDArrayObject; use crate::{ @@ -122,4 +122,54 @@ impl<'ctx> NDArrayObject<'ctx> { let fill_value = ndarray_one_value(generator, ctx, dtype); NDArrayObject::make_np_full(generator, ctx, dtype, ndims, shape, fill_value) } + + /// Create an ndarray like `np.eye`. + pub fn make_np_eye( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + nrows: Instance<'ctx, Int>, + ncols: Instance<'ctx, Int>, + offset: Instance<'ctx, Int>, + ) -> Self { + let ndzero = ndarray_zero_value(generator, ctx, dtype); + let ndone = ndarray_one_value(generator, ctx, dtype); + + let ndarray = NDArrayObject::alloca_dynamic_shape(generator, ctx, dtype, &[nrows, ncols]); + + // Create data and make the matrix like look np.eye() + ndarray.create_data(generator, ctx); + ndarray + .foreach(generator, ctx, |generator, ctx, _hooks, nditer| { + // NOTE: rows and cols can never be zero here, since this ndarray's `np.size` would be zero + // and this loop would not execute. + + // Load up `row_i` and `col_i` from indices. + let row_i = nditer.get_indices().get_index_const(generator, ctx, 0); + let col_i = nditer.get_indices().get_index_const(generator, ctx, 1); + + let be_one = row_i.add(ctx, offset).compare(ctx, IntPredicate::EQ, col_i); + let value = ctx.builder.build_select(be_one.value, ndone, ndzero, "value").unwrap(); + + let p = nditer.get_pointer(generator, ctx); + ctx.builder.build_store(p, value).unwrap(); + + Ok(()) + }) + .unwrap(); + + ndarray + } + + /// Create an ndarray like `np.identity`. + pub fn make_np_identity( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dtype: Type, + size: Instance<'ctx, Int>, + ) -> Self { + // Convenient implementation + let offset = Int(SizeT).const_0(generator, ctx.ctx); + NDArrayObject::make_np_eye(generator, ctx, dtype, size, size, offset) + } } diff --git a/nac3core/src/codegen/object/ndarray/mod.rs b/nac3core/src/codegen/object/ndarray/mod.rs index 85643a60..1c301750 100644 --- a/nac3core/src/codegen/object/ndarray/mod.rs +++ b/nac3core/src/codegen/object/ndarray/mod.rs @@ -5,7 +5,7 @@ use inkwell::{ AddressSpace, }; -use super::any::AnyObject; +use super::{any::AnyObject, tuple::TupleObject}; use crate::{ codegen::{ irrt::{ @@ -417,6 +417,8 @@ impl<'ctx> NDArrayObject<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, value: BasicValueEnum<'ctx>, ) { + // TODO: It is possible to optimize this by exploiting contiguous strides with memset. + // Probably best to implement in IRRT. self.foreach(generator, ctx, |generator, ctx, _hooks, nditer| { let p = nditer.get_pointer(generator, ctx); ctx.builder.build_store(p, value).unwrap(); @@ -424,6 +426,62 @@ impl<'ctx> NDArrayObject<'ctx> { }) .unwrap(); } + + /// Create the shape tuple of this ndarray like `np.shape()`. + /// + /// The returned integers in the tuple are in int32. + pub fn make_shape_tuple( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> TupleObject<'ctx> { + // TODO: Return a tuple of SizeT + + let mut objects = Vec::with_capacity(self.ndims as usize); + + for i in 0..self.ndims { + let dim = self + .instance + .get(generator, ctx, |f| f.shape) + .get_index_const(generator, ctx, i64::try_from(i).unwrap()) + .truncate_or_bit_cast(generator, ctx, Int32); + + objects.push(AnyObject { + ty: ctx.primitives.int32, + value: dim.value.as_basic_value_enum(), + }); + } + + TupleObject::from_objects(generator, ctx, objects) + } + + /// Create the strides tuple of this ndarray like `.strides`. + /// + /// The returned integers in the tuple are in int32. + pub fn make_strides_tuple( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> TupleObject<'ctx> { + // TODO: Return a tuple of SizeT. + + let mut objects = Vec::with_capacity(self.ndims as usize); + + for i in 0..self.ndims { + let dim = self + .instance + .get(generator, ctx, |f| f.strides) + .get_index_const(generator, ctx, i64::try_from(i).unwrap()) + .truncate_or_bit_cast(generator, ctx, Int32); + + objects.push(AnyObject { + ty: ctx.primitives.int32, + value: dim.value.as_basic_value_enum(), + }); + } + + TupleObject::from_objects(generator, ctx, objects) + } } /// A convenience enum for implementing functions that acts on scalars or ndarrays or both. diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 4d88b6ee..c95c6186 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -19,7 +19,9 @@ use crate::{ codegen::{ builtin_fns, classes::{ProxyValue, RangeValue}, + model::*, numpy::*, + object::{any::AnyObject, ndarray::NDArrayObject}, stmt::exn_constructor, }, symbol_resolver::SymbolValue, @@ -512,6 +514,10 @@ impl<'a> BuiltinBuilder<'a> { | PrimDef::FunNpEye | PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim), + PrimDef::FunNpSize | PrimDef::FunNpShape | PrimDef::FunNpStrides => { + self.build_ndarray_property_getter_function(prim) + } + PrimDef::FunStr => self.build_str_function(), PrimDef::FunFloor | PrimDef::FunFloor64 | PrimDef::FunCeil | PrimDef::FunCeil64 => { @@ -1386,6 +1392,78 @@ impl<'a> BuiltinBuilder<'a> { } } + fn build_ndarray_property_getter_function(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed( + prim, + &[PrimDef::FunNpSize, PrimDef::FunNpShape, PrimDef::FunNpStrides], + ); + + let in_ndarray_ty = self.unifier.get_fresh_var_with_range( + &[self.primitives.ndarray], + Some("T".into()), + None, + ); + + match prim { + PrimDef::FunNpSize => create_fn_by_codegen( + self.unifier, + &into_var_map([in_ndarray_ty]), + prim.name(), + self.primitives.int32, + &[(in_ndarray_ty.ty, "a")], + Box::new(|ctx, obj, fun, args, generator| { + assert!(obj.is_none()); + assert_eq!(args.len(), 1); + + let ndarray_ty = fun.0.args[0].ty; + let ndarray = + args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?; + let ndarray = AnyObject { ty: ndarray_ty, value: ndarray }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); + + let size = + ndarray.size(generator, ctx).truncate_or_bit_cast(generator, ctx, Int32); + Ok(Some(size.value.as_basic_value_enum())) + }), + ), + PrimDef::FunNpShape | PrimDef::FunNpStrides => { + // The function signatures of `np_shape` an `np_size` are the same. + // Mixed together for convenience. + + // The return type is a tuple of variable length depending on the ndims of the input ndarray. + let ret_ty = self.unifier.get_dummy_var().ty; // Handled by special folding + + create_fn_by_codegen( + self.unifier, + &into_var_map([in_ndarray_ty]), + prim.name(), + ret_ty, + &[(in_ndarray_ty.ty, "a")], + Box::new(move |ctx, obj, fun, args, generator| { + assert!(obj.is_none()); + assert_eq!(args.len(), 1); + + let ndarray_ty = fun.0.args[0].ty; + let ndarray = + args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?; + + let ndarray = AnyObject { ty: ndarray_ty, value: ndarray }; + let ndarray = NDArrayObject::from_object(generator, ctx, ndarray); + + let result_tuple = match prim { + PrimDef::FunNpShape => ndarray.make_shape_tuple(generator, ctx), + PrimDef::FunNpStrides => ndarray.make_strides_tuple(generator, ctx), + _ => unreachable!(), + }; + + Ok(Some(result_tuple.value.as_basic_value_enum())) + }), + ) + } + _ => unreachable!(), + } + } + /// Build the `str()` function. fn build_str_function(&mut self) -> TopLevelDef { let prim = PrimDef::FunStr; @@ -1888,8 +1966,8 @@ impl<'a> BuiltinBuilder<'a> { self.unifier, &into_var_map([ndarray_ty]), prim.name(), - ndarray_ty.ty, - &[(ndarray_ty.ty, "x")], + self.ndarray_num_ty, + &[(self.ndarray_num_ty, "x")], Box::new(move |ctx, _, fun, args, generator| { let arg_ty = fun.0.args[0].ty; let arg_val = diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 71ee35b0..46e4555f 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -54,6 +54,11 @@ pub enum PrimDef { FunNpEye, FunNpIdentity, + // NumPy ndarray property getters + FunNpSize, + FunNpShape, + FunNpStrides, + // Miscellaneous NumPy & SciPy functions FunNpRound, FunNpFloor, @@ -240,6 +245,11 @@ impl PrimDef { PrimDef::FunNpEye => fun("np_eye", None), PrimDef::FunNpIdentity => fun("np_identity", None), + // NumPy NDArray property getters, + PrimDef::FunNpSize => fun("np_size", None), + PrimDef::FunNpShape => fun("np_shape", None), + PrimDef::FunNpStrides => fun("np_strides", None), + // Miscellaneous NumPy & SciPy functions PrimDef::FunNpRound => fun("np_round", None), PrimDef::FunNpFloor => fun("np_floor", None), diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index 53ff774f..78e19dda 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -5,7 +5,7 @@ expression: res_vec [ "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(241)]\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(246)]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index 2621337c..b4df49c9 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -7,7 +7,7 @@ expression: res_vec "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B[typevar230]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar230\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B[typevar235]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar235\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index d0769305..65a6a8ac 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -5,8 +5,8 @@ expression: res_vec [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(243)]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(248)]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(248)]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(253)]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index 5ebdf86c..cfedf1f6 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar229, typevar230]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar229\", \"typevar230\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar234, typevar235]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar234\", \"typevar235\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index 502abbd6..e84c450a 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -6,12 +6,12 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(249)]\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(254)]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(257)]\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(262)]\n}\n", ] diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index e5a6cf89..22c953a8 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -3,7 +3,7 @@ use std::{ cmp::max, collections::{HashMap, HashSet}, convert::{From, TryInto}, - iter::once, + iter::{self, once}, sync::Arc, }; @@ -1235,6 +1235,45 @@ impl<'a> Inferencer<'a> { })); } + if ["np_shape".into(), "np_strides".into()].contains(id) && args.len() == 1 { + let ndarray = self.fold_expr(args.remove(0))?; + + let ndims = arraylike_get_ndims(self.unifier, ndarray.custom.unwrap()); + + // Make a tuple of size `ndims` full of int32 (TODO: Make it usize) + let ret_ty = TypeEnum::TTuple { + ty: iter::repeat(self.primitives.int32).take(ndims as usize).collect_vec(), + is_vararg_ctx: false, + }; + let ret_ty = self.unifier.add_ty(ret_ty); + + let func_ty = TypeEnum::TFunc(FunSignature { + args: vec![FuncArg { + name: "a".into(), + default_value: None, + ty: ndarray.custom.unwrap(), + is_vararg: false, + }], + ret: ret_ty, + vars: VarMap::new(), + }); + let func_ty = self.unifier.add_ty(func_ty); + + return Ok(Some(Located { + location, + custom: Some(ret_ty), + node: ExprKind::Call { + func: Box::new(Located { + custom: Some(func_ty), + location: func.location, + node: ExprKind::Name { id: *id, ctx: *ctx }, + }), + args: vec![ndarray], + keywords: vec![], + }, + })); + } + if id == &"np_dot".into() { let arg0 = self.fold_expr(args.remove(0))?; let arg1 = self.fold_expr(args.remove(0))?; diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index 4f19db95..ca17c3da 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -179,6 +179,11 @@ def patch(module): module.np_identity = np.identity module.np_array = np.array + # NumPy NDArray property getters + module.np_size = np.size + module.np_shape = np.shape + module.np_strides = lambda ndarray: ndarray.strides + # NumPy Math functions module.np_isnan = np.isnan module.np_isinf = np.isinf