diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 7762c7017..05fb9885f 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -871,6 +871,58 @@ impl<'a> Inferencer<'a> { } } } + // array() is a "magic" function call that determines the number of + // dimensions in the result from the nesting of the array argument type, + // to match the host Python NumPy API. + if id == "array".into() { + if args.is_empty() { + return report_error( + "`array()` expects at least one argument (contents in list form)", + func_location, + ); + } + + if args.len() > 2 || !keywords.is_empty() { + return report_error( + "Additional `array()` arguments not yet implemented", + func_location, + ); + // TODO: Implement `dtype=` kwarg. + } + + let list_arg = self.fold_expr(args.remove(0))?; + // TODO: Implement special case for emtpy arrays (e.g. `array([[]]))`) + // to match NumPy. + + let mut num_dims = 0; + let mut elem_type = list_arg.custom.unwrap(); + while let TypeEnum::TList { ty } = &*self.unifier.get_ty(elem_type) { + elem_type = *ty; + num_dims += 1; + } + if num_dims == 0 { + return report_error( + "expected list argument to array(), not xxx", + func_location, + ); + } + + let custom = + Some(self.unifier.add_ty(TypeEnum::TNDArray { ty: elem_type, num_dims })); + return Ok(Located { + location, + custom, + node: ExprKind::Call { + func: Box::new(Located { + custom: None, + location: func.location, + node: ExprKind::Name { id, ctx }, + }), + args: vec![list_arg], + keywords: vec![], + }, + }); + } Located { location: func_location, custom, node: ExprKind::Name { id, ctx } } } else { func diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index 74930814a..3e58795c7 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -513,6 +513,14 @@ impl TestEnvironment { [("a", "list[int32]"), ("b", "list[int32]")].iter().cloned().collect(), &[] ; "listcomp test")] +#[test_case( + indoc! {" + a = array([1, 2]) + b = array([[1, 2], [3, 4]]) + "}, + [("a", "ndarray[int32, 1]"), ("b", "ndarray[int32, 2]")].iter().cloned().collect(), + &[] + ; "array test")] #[test_case(indoc! {" a = virtual(Bar(), Bar) b = a.b() @@ -533,6 +541,7 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>, virtuals: &[(&str, &st let mut env = TestEnvironment::new(); let id_to_name = std::mem::take(&mut env.id_to_name); let mut defined_identifiers: HashSet<_> = env.identifier_mapping.keys().cloned().collect(); + defined_identifiers.insert("array".into()); defined_identifiers.insert("virtual".into()); let mut inferencer = env.get_inferencer(); inferencer.defined_identifiers = defined_identifiers.clone(); diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index c028c7246..fdd226aec 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -137,6 +137,10 @@ pub enum TypeEnum { TList { ty: Type, }, + TNDArray { + ty: Type, + num_dims: u8, + }, TObj { obj_id: DefinitionId, fields: Mapping, @@ -156,6 +160,7 @@ impl TypeEnum { TypeEnum::TVar { .. } => "TVar", TypeEnum::TTuple { .. } => "TTuple", TypeEnum::TList { .. } => "TList", + TypeEnum::TNDArray { .. } => "TNDArray", TypeEnum::TObj { .. } => "TObj", TypeEnum::TVirtual { .. } => "TVirtual", TypeEnum::TCall { .. } => "TCall", @@ -387,6 +392,7 @@ impl Unifier { TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)), TCall { .. } => false, TList { ty } => self.is_concrete(*ty, allowed_typevars), + TNDArray { ty, .. } => self.is_concrete(*ty, allowed_typevars), TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)), TObj { params: vars, .. } => { vars.values().all(|ty| self.is_concrete(*ty, allowed_typevars)) @@ -904,6 +910,13 @@ impl Unifier { TypeEnum::TList { ty } => { format!("list[{}]", self.internal_stringify(*ty, obj_to_name, var_to_name, notes)) } + TypeEnum::TNDArray { ty, num_dims } => { + format!( + "ndarray[{}, {}]", + self.internal_stringify(*ty, obj_to_name, var_to_name, notes), + num_dims + ) + } TypeEnum::TVirtual { ty } => { format!( "virtual[{}]",