diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index e212ac8..fdef020 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -1,6 +1,7 @@ use std::convert::TryInto; use crate::symbol_resolver::SymbolValue; +use crate::toplevel::numpy::unpack_ndarray_var_tys; use crate::typecheck::typedef::{Mapping, VarMap}; use nac3parser::ast::{Constant, Location}; @@ -691,3 +692,35 @@ pub fn parse_parameter_default_value( ])) } } + +/// Obtains the element type of an array-like type. +pub fn arraylike_flatten_element_type(unifier: &mut Unifier, ty: Type) -> Type { + match &*unifier.get_ty(ty) { + TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => + unpack_ndarray_var_tys(unifier, ty).0, + + TypeEnum::TList { ty } => arraylike_flatten_element_type(unifier, *ty), + _ => ty + } +} + +/// Obtains the number of dimensions of an array-like type. +pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 { + match &*unifier.get_ty(ty) { + TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { + let ndims = unpack_ndarray_var_tys(unifier, ty).1; + let TypeEnum::TLiteral { values, .. } = &*unifier.get_ty_immutable(ndims) else { + panic!("Expected TLiteral for ndarray.ndims, got {}", unifier.stringify(ndims)) + }; + + if values.len() > 1 { + todo!("Getting num of dimensions for ndarray with more than one ndim bound is unimplemented") + } + + u64::try_from(values[0].clone()).unwrap() + } + + TypeEnum::TList { ty } => arraylike_get_ndims(unifier, *ty) + 1, + _ => 0 + } +} diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 88b27f0..eb512ca 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -9,7 +9,7 @@ use super::{magic_methods::*, type_error::TypeError, typedef::CallId}; use crate::{ symbol_resolver::{SymbolResolver, SymbolValue}, toplevel::{ - helper::PRIMITIVE_DEF_IDS, + helper::{arraylike_flatten_element_type, arraylike_get_ndims, PRIMITIVE_DEF_IDS}, numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, TopLevelContext, }, @@ -1254,6 +1254,61 @@ impl<'a> Inferencer<'a> { })) } + // 3-argument ndarray n-dimensional creation functions + if id == &"np_array".into() && args.len() == 3 { + let arg0 = self.fold_expr(args.remove(0))?; + let arg1 = self.fold_expr(args.remove(0))?; + let arg2 = self.fold_expr(args.remove(0))?; + + let ty = arraylike_flatten_element_type(self.unifier, arg0.custom.unwrap()); + let ndims = arraylike_get_ndims(self.unifier, arg0.custom.unwrap()); + let ndims = self.unifier.get_fresh_literal( + vec![SymbolValue::U64(ndims)], + None, + ); + let ret = make_ndarray_ty( + self.unifier, + self.primitives, + Some(ty), + Some(ndims), + ); + let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![ + FuncArg { + name: "object".into(), + ty: arg0.custom.unwrap(), + default_value: None + }, + FuncArg { + name: "copy".into(), + ty: arg1.custom.unwrap(), + default_value: Some(SymbolValue::Bool(true)), + }, + FuncArg { + name: "ndmin".into(), + ty: arg2.custom.unwrap(), + default_value: Some(SymbolValue::U32(0)), + }, + ], + ret, + vars: VarMap::new(), + })); + + return Ok(Some(Located { + location, + custom: Some(ret), + node: ExprKind::Call { + func: Box::new(Located { + custom: Some(custom), + location: func.location, + node: ExprKind::Name { id: *id, ctx: ctx.clone() }, + }), + args: vec![arg0, arg1, arg2], + keywords: vec![], + }, + })) + } + Ok(None) }