From 5640a793e2fcf98151848cfee9fa831e9d11e06c Mon Sep 17 00:00:00 2001 From: lyken Date: Fri, 30 Aug 2024 17:11:38 +0800 Subject: [PATCH] core: allow np_full to take tuple shapes --- nac3core/src/typecheck/type_inferencer/mod.rs | 31 +++++++------------ 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index a5b8cd496..d27f746e8 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -1550,36 +1550,29 @@ impl<'a> Inferencer<'a> { } // 2-argument ndarray n-dimensional creation functions if id == &"np_full".into() && args.len() == 2 { - let ExprKind::List { elts, .. } = &args[0].node else { - return report_error( - format!( - "Expected List literal for first argument of {id}, got {}", - args[0].node.name() - ) - .as_str(), - args[0].location, - ); - }; + // Parse arguments + let shape_expr = args.remove(0); + let (ndims, shape) = + self.fold_numpy_function_call_shape_argument(*id, 0, shape_expr)?; // Special handling for `shape` - let ndims = elts.len() as u64; + let fill_value = self.fold_expr(args.remove(0))?; - let arg0 = self.fold_expr(args.remove(0))?; - let arg1 = self.fold_expr(args.remove(0))?; - - let ty = arg1.custom.unwrap(); + // Build the return type + let dtype = fill_value.custom.unwrap(); let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None); - let ret = make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims)); + let ret = make_ndarray_ty(self.unifier, self.primitives, Some(dtype), Some(ndims)); + let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![ FuncArg { name: "shape".into(), - ty: arg0.custom.unwrap(), + ty: shape.custom.unwrap(), default_value: None, is_vararg: false, }, FuncArg { name: "fill_value".into(), - ty: arg1.custom.unwrap(), + ty: fill_value.custom.unwrap(), default_value: None, is_vararg: false, }, @@ -1597,7 +1590,7 @@ impl<'a> Inferencer<'a> { location: func.location, node: ExprKind::Name { id: *id, ctx: *ctx }, }), - args: vec![arg0, arg1], + args: vec![shape, fill_value], keywords: vec![], }, }));