forked from M-Labs/nac3
1
0
Fork 0

core: allow np_full to take tuple shapes

This commit is contained in:
lyken 2024-08-30 17:11:38 +08:00 committed by David Mak
parent abbaa506ad
commit 5640a793e2
1 changed files with 12 additions and 19 deletions

View File

@ -1550,36 +1550,29 @@ impl<'a> Inferencer<'a> {
} }
// 2-argument ndarray n-dimensional creation functions // 2-argument ndarray n-dimensional creation functions
if id == &"np_full".into() && args.len() == 2 { if id == &"np_full".into() && args.len() == 2 {
let ExprKind::List { elts, .. } = &args[0].node else { // Parse arguments
return report_error( let shape_expr = args.remove(0);
format!( let (ndims, shape) =
"Expected List literal for first argument of {id}, got {}", self.fold_numpy_function_call_shape_argument(*id, 0, shape_expr)?; // Special handling for `shape`
args[0].node.name()
)
.as_str(),
args[0].location,
);
};
let ndims = elts.len() as u64; let fill_value = self.fold_expr(args.remove(0))?;
let arg0 = self.fold_expr(args.remove(0))?; // Build the return type
let arg1 = self.fold_expr(args.remove(0))?; let dtype = fill_value.custom.unwrap();
let ty = arg1.custom.unwrap();
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None); let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
let ret = make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims)); let ret = make_ndarray_ty(self.unifier, self.primitives, Some(dtype), Some(ndims));
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature { let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![ args: vec![
FuncArg { FuncArg {
name: "shape".into(), name: "shape".into(),
ty: arg0.custom.unwrap(), ty: shape.custom.unwrap(),
default_value: None, default_value: None,
is_vararg: false, is_vararg: false,
}, },
FuncArg { FuncArg {
name: "fill_value".into(), name: "fill_value".into(),
ty: arg1.custom.unwrap(), ty: fill_value.custom.unwrap(),
default_value: None, default_value: None,
is_vararg: false, is_vararg: false,
}, },
@ -1597,7 +1590,7 @@ impl<'a> Inferencer<'a> {
location: func.location, location: func.location,
node: ExprKind::Name { id: *id, ctx: *ctx }, node: ExprKind::Name { id: *id, ctx: *ctx },
}), }),
args: vec![arg0, arg1], args: vec![shape, fill_value],
keywords: vec![], keywords: vec![],
}, },
})); }));