From c2fdb123972aa8319d16e7764518cdb99578c620 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 4 Jun 2024 17:50:09 +0800 Subject: [PATCH] core/type_inferencer: Add special rule for np_array --- nac3core/src/typecheck/type_inferencer/mod.rs | 80 +++++++++++++++++-- 1 file changed, 75 insertions(+), 5 deletions(-) diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index e9b62cb0e..575e5ed25 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -9,7 +9,7 @@ use super::{magic_methods::*, type_error::TypeError, typedef::CallId}; use crate::{ symbol_resolver::{SymbolResolver, SymbolValue}, toplevel::{ - helper::PRIMITIVE_DEF_IDS, + helper::{arraylike_flatten_element_type, arraylike_get_ndims, PRIMITIVE_DEF_IDS}, numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, TopLevelContext, }, @@ -1254,6 +1254,77 @@ impl<'a> Inferencer<'a> { })) } + // 1-argument ndarray n-dimensional creation functions + if id == &"np_array".into() && args.len() == 1 { + let arg0 = self.fold_expr(args.remove(0))?; + + let keywords = keywords.iter() + .map(|v| fold::fold_keyword(self, v.clone())) + .collect::, _>>()?; + let ndmin_kw = keywords.iter() + .find(|kwarg| kwarg.node.arg.is_some_and(|id| id == "ndmin".into())); + + let ty = arraylike_flatten_element_type(self.unifier, arg0.custom.unwrap()); + let ndims = if let Some(ndmin_kw) = ndmin_kw { + match &ndmin_kw.node.value.node { + ExprKind::Constant { value, .. } => match value { + ast::Constant::Int(value) => *value as u64, + _ => return Err(HashSet::from(["Expected uint64 for ndims".to_string()])), + } + + _ => arraylike_get_ndims(self.unifier, arg0.custom.unwrap()) + } + } else { + arraylike_get_ndims(self.unifier, arg0.custom.unwrap()) + }; + let ndims = self.unifier.get_fresh_literal( + vec![SymbolValue::U64(ndims)], + None, + ); + let ret = make_ndarray_ty( + self.unifier, + self.primitives, + Some(ty), + Some(ndims), + ); + + let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![ + FuncArg { + name: "object".into(), + ty: arg0.custom.unwrap(), + default_value: None + }, + FuncArg { + name: "copy".into(), + ty: self.primitives.bool, + default_value: Some(SymbolValue::Bool(true)), + }, + FuncArg { + name: "ndmin".into(), + ty: self.primitives.int32, + default_value: Some(SymbolValue::U32(0)), + }, + ], + ret, + vars: VarMap::new(), + })); + + return Ok(Some(Located { + location, + custom: Some(ret), + node: ExprKind::Call { + func: Box::new(Located { + custom: Some(custom), + location: func.location, + node: ExprKind::Name { id: *id, ctx: ctx.clone() }, + }), + args: vec![arg0], + keywords, + }, + })) + } + Ok(None) } @@ -1264,11 +1335,10 @@ impl<'a> Inferencer<'a> { mut args: Vec>, keywords: Vec>, ) -> Result>, HashSet> { - let func = if let Some(spec_call_func) = self.try_fold_special_call(location, &func, &mut args, &keywords)? { + if let Some(spec_call_func) = self.try_fold_special_call(location, &func, &mut args, &keywords)? { return Ok(spec_call_func) - } else { - func - }; + } + let func = Box::new(self.fold_expr(func)?); let args = args.into_iter().map(|v| self.fold_expr(v)).collect::, _>>()?; let keywords = keywords