core: support tuple and int32 input for np_empty, np_ones, and more #434

Merged
sb10q merged 1 commits from ndfactory-tuple into master 2024-08-17 17:37:21 +08:00
1 changed files with 2 additions and 2 deletions
Showing only changes of commit 9808923258 - Show all commits

View File

@ -880,7 +880,7 @@ impl<'a> Inferencer<'a> {
// Check `shape_ty` to see if its a list of int32s, a tuple of int32s, or just int32. // Check `shape_ty` to see if its a list of int32s, a tuple of int32s, or just int32.
// Otherwise throw an error as that would mean the user wrote an ill-typed `shape_expr`. // Otherwise throw an error as that would mean the user wrote an ill-typed `shape_expr`.
// //
// Here, we also take the opportunity to deduce `ndims` statically for 2. and 3. // Here, we also take the opportunity to deduce `ndims` statically.
let shape_ty_enum = &*self.unifier.get_ty(shape_ty); let shape_ty_enum = &*self.unifier.get_ty(shape_ty);
let ndims = match shape_ty_enum { let ndims = match shape_ty_enum {
TypeEnum::TList { ty } => { TypeEnum::TList { ty } => {
@ -1292,7 +1292,7 @@ impl<'a> Inferencer<'a> {
{ {
let shape_expr = args.remove(0); let shape_expr = args.remove(0);
let (ndims, shape) = let (ndims, shape) =
self.fold_numpy_function_call_shape_argument(*id, 0, shape_expr)?; // Special handling the `shape` self.fold_numpy_function_call_shape_argument(*id, 0, shape_expr)?; // Special handling for `shape`
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None); let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
let ret = make_ndarray_ty( let ret = make_ndarray_ty(