core: add np.transpose and np.reshape functions
This commit is contained in:
parent
a3e6bb2292
commit
00236f48bc
|
@ -2026,3 +2026,394 @@ pub fn gen_ndarray_fill<'ctx>(
|
|||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for `ndarray.transpose`.
|
||||
pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "ndarray_transpose";
|
||||
let (x1_ty, x1) = x1;
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
||||
|
||||
// Dimensions are reversed in the transposed array
|
||||
let out = create_ndarray_dyn_shape(
|
||||
generator,
|
||||
ctx,
|
||||
elem_ty,
|
||||
&n1,
|
||||
|_, ctx, n| Ok(n.load_ndims(ctx)),
|
||||
|generator, ctx, n, idx| {
|
||||
let new_idx = ctx.builder.build_int_sub(n.load_ndims(ctx), idx, "").unwrap();
|
||||
let new_idx = ctx
|
||||
.builder
|
||||
.build_int_sub(new_idx, new_idx.get_type().const_int(1, false), "")
|
||||
.unwrap();
|
||||
unsafe { Ok(n.dim_sizes().get_typed_unchecked(ctx, generator, &new_idx, None)) }
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
gen_for_callback_incrementing(
|
||||
generator,
|
||||
ctx,
|
||||
None,
|
||||
llvm_usize.const_zero(),
|
||||
(n_sz, false),
|
||||
|generator, ctx, _, idx| {
|
||||
let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
|
||||
|
||||
let new_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||
let rem_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||
ctx.builder.build_store(new_idx, llvm_usize.const_zero()).unwrap();
|
||||
ctx.builder.build_store(rem_idx, idx).unwrap();
|
||||
|
||||
// Incrementally calculate the new index in the transposed array
|
||||
// For each index, we first decompose it into the n-dims and use those to reconstruct the new index
|
||||
// The formula used for indexing is:
|
||||
// idx = dim_n * ( ... (dim2 * (dim0 * dim1) + dim1) + dim2 ... ) + dim_n
|
||||
gen_for_callback_incrementing(
|
||||
generator,
|
||||
ctx,
|
||||
None,
|
||||
llvm_usize.const_zero(),
|
||||
(n1.load_ndims(ctx), false),
|
||||
|generator, ctx, _, ndim| {
|
||||
let ndim_rev =
|
||||
ctx.builder.build_int_sub(n1.load_ndims(ctx), ndim, "").unwrap();
|
||||
let ndim_rev = ctx
|
||||
.builder
|
||||
.build_int_sub(ndim_rev, llvm_usize.const_int(1, false), "")
|
||||
.unwrap();
|
||||
let dim = unsafe {
|
||||
n1.dim_sizes().get_typed_unchecked(ctx, generator, &ndim_rev, None)
|
||||
};
|
||||
|
||||
let rem_idx_val =
|
||||
ctx.builder.build_load(rem_idx, "").unwrap().into_int_value();
|
||||
let new_idx_val =
|
||||
ctx.builder.build_load(new_idx, "").unwrap().into_int_value();
|
||||
|
||||
let add_component =
|
||||
ctx.builder.build_int_unsigned_rem(rem_idx_val, dim, "").unwrap();
|
||||
let rem_idx_val =
|
||||
ctx.builder.build_int_unsigned_div(rem_idx_val, dim, "").unwrap();
|
||||
|
||||
let new_idx_val = ctx.builder.build_int_mul(new_idx_val, dim, "").unwrap();
|
||||
let new_idx_val =
|
||||
ctx.builder.build_int_add(new_idx_val, add_component, "").unwrap();
|
||||
|
||||
ctx.builder.build_store(rem_idx, rem_idx_val).unwrap();
|
||||
ctx.builder.build_store(new_idx, new_idx_val).unwrap();
|
||||
|
||||
Ok(())
|
||||
},
|
||||
llvm_usize.const_int(1, false),
|
||||
)?;
|
||||
|
||||
let new_idx_val = ctx.builder.build_load(new_idx, "").unwrap().into_int_value();
|
||||
unsafe { out.data().set_unchecked(ctx, generator, &new_idx_val, elem) };
|
||||
Ok(())
|
||||
},
|
||||
llvm_usize.const_int(1, false),
|
||||
)?;
|
||||
|
||||
Ok(out.as_base_value().into())
|
||||
} else {
|
||||
unreachable!(
|
||||
"{FN_NAME}() not supported for '{}'",
|
||||
format!("'{}'", ctx.unifier.stringify(x1_ty))
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// LLVM-typed implementation for generating the implementation for `ndarray.reshape`.
|
||||
///
|
||||
/// * `x1` - `NDArray` to reshape.
|
||||
/// * `shape` - The `shape` parameter used to construct the new `NDArray`.
|
||||
/// Just like numpy, the `shape` argument can be:
|
||||
/// 1. A list of `int32`; e.g., `np.reshape(arr, [600, -1, 3])`
|
||||
/// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))`
|
||||
/// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)`
|
||||
/// Note that unlike other generating functions, one of the dimesions in the shape can be negative
|
||||
pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
x1: (Type, BasicValueEnum<'ctx>),
|
||||
shape: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "ndarray_reshape";
|
||||
let (x1_ty, x1) = x1;
|
||||
let (_, shape) = shape;
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
||||
|
||||
let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||
let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||
ctx.builder.build_store(acc, llvm_usize.const_int(1, false)).unwrap();
|
||||
ctx.builder.build_store(num_neg, llvm_usize.const_zero()).unwrap();
|
||||
|
||||
let out = match shape {
|
||||
BasicValueEnum::PointerValue(shape_list_ptr)
|
||||
if ListValue::is_instance(shape_list_ptr, llvm_usize).is_ok() =>
|
||||
{
|
||||
// 1. A list of ints; e.g., `np.reshape(arr, [int64(600), int64(800, -1])`
|
||||
|
||||
let shape_list = ListValue::from_ptr_val(shape_list_ptr, llvm_usize, None);
|
||||
// Check for -1 in dimensions
|
||||
gen_for_callback_incrementing(
|
||||
generator,
|
||||
ctx,
|
||||
None,
|
||||
llvm_usize.const_zero(),
|
||||
(shape_list.load_size(ctx, None), false),
|
||||
|generator, ctx, _, idx| {
|
||||
let ele =
|
||||
shape_list.data().get(ctx, generator, &idx, None).into_int_value();
|
||||
let ele = ctx.builder.build_int_s_extend(ele, llvm_usize, "").unwrap();
|
||||
|
||||
gen_if_else_expr_callback(
|
||||
generator,
|
||||
ctx,
|
||||
|_, ctx| {
|
||||
Ok(ctx
|
||||
.builder
|
||||
.build_int_compare(
|
||||
IntPredicate::SLT,
|
||||
ele,
|
||||
llvm_usize.const_zero(),
|
||||
"",
|
||||
)
|
||||
.unwrap())
|
||||
},
|
||||
|_, ctx| -> Result<Option<IntValue>, String> {
|
||||
let num_neg_value =
|
||||
ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
|
||||
let num_neg_value = ctx
|
||||
.builder
|
||||
.build_int_add(
|
||||
num_neg_value,
|
||||
llvm_usize.const_int(1, false),
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
ctx.builder.build_store(num_neg, num_neg_value).unwrap();
|
||||
Ok(None)
|
||||
},
|
||||
|_, ctx| {
|
||||
let acc_value =
|
||||
ctx.builder.build_load(acc, "").unwrap().into_int_value();
|
||||
let acc_value =
|
||||
ctx.builder.build_int_mul(acc_value, ele, "").unwrap();
|
||||
ctx.builder.build_store(acc, acc_value).unwrap();
|
||||
Ok(None)
|
||||
},
|
||||
)?;
|
||||
Ok(())
|
||||
},
|
||||
llvm_usize.const_int(1, false),
|
||||
)?;
|
||||
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
|
||||
let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
|
||||
// Generate the output shape by filling -1 with `rem`
|
||||
create_ndarray_dyn_shape(
|
||||
generator,
|
||||
ctx,
|
||||
elem_ty,
|
||||
&shape_list,
|
||||
|_, ctx, _| Ok(shape_list.load_size(ctx, None)),
|
||||
|generator, ctx, shape_list, idx| {
|
||||
let dim =
|
||||
shape_list.data().get(ctx, generator, &idx, None).into_int_value();
|
||||
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
|
||||
|
||||
Ok(gen_if_else_expr_callback(
|
||||
generator,
|
||||
ctx,
|
||||
|_, ctx| {
|
||||
Ok(ctx
|
||||
.builder
|
||||
.build_int_compare(
|
||||
IntPredicate::SLT,
|
||||
dim,
|
||||
llvm_usize.const_zero(),
|
||||
"",
|
||||
)
|
||||
.unwrap())
|
||||
},
|
||||
|_, _| Ok(Some(rem)),
|
||||
|_, _| Ok(Some(dim)),
|
||||
)?
|
||||
.unwrap()
|
||||
.into_int_value())
|
||||
},
|
||||
)
|
||||
}
|
||||
BasicValueEnum::StructValue(shape_tuple) => {
|
||||
// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))`
|
||||
|
||||
let ndims = shape_tuple.get_type().count_fields();
|
||||
// Check for -1 in dims
|
||||
for dim_i in 0..ndims {
|
||||
let dim = ctx
|
||||
.builder
|
||||
.build_extract_value(shape_tuple, dim_i, "")
|
||||
.unwrap()
|
||||
.into_int_value();
|
||||
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
|
||||
|
||||
gen_if_else_expr_callback(
|
||||
generator,
|
||||
ctx,
|
||||
|_, ctx| {
|
||||
Ok(ctx
|
||||
.builder
|
||||
.build_int_compare(
|
||||
IntPredicate::SLT,
|
||||
dim,
|
||||
llvm_usize.const_zero(),
|
||||
"",
|
||||
)
|
||||
.unwrap())
|
||||
},
|
||||
|_, ctx| -> Result<Option<IntValue>, String> {
|
||||
let num_negs =
|
||||
ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
|
||||
let num_negs = ctx
|
||||
.builder
|
||||
.build_int_add(num_negs, llvm_usize.const_int(1, false), "")
|
||||
.unwrap();
|
||||
ctx.builder.build_store(num_neg, num_negs).unwrap();
|
||||
Ok(None)
|
||||
},
|
||||
|_, ctx| {
|
||||
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
|
||||
let acc_val = ctx.builder.build_int_mul(acc_val, dim, "").unwrap();
|
||||
ctx.builder.build_store(acc, acc_val).unwrap();
|
||||
Ok(None)
|
||||
},
|
||||
)?;
|
||||
}
|
||||
|
||||
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
|
||||
let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
|
||||
let mut shape = Vec::with_capacity(ndims as usize);
|
||||
|
||||
// Reconstruct shape filling negatives with rem
|
||||
for dim_i in 0..ndims {
|
||||
let dim = ctx
|
||||
.builder
|
||||
.build_extract_value(shape_tuple, dim_i, "")
|
||||
.unwrap()
|
||||
.into_int_value();
|
||||
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
|
||||
|
||||
let dim = gen_if_else_expr_callback(
|
||||
generator,
|
||||
ctx,
|
||||
|_, ctx| {
|
||||
Ok(ctx
|
||||
.builder
|
||||
.build_int_compare(
|
||||
IntPredicate::SLT,
|
||||
dim,
|
||||
llvm_usize.const_zero(),
|
||||
"",
|
||||
)
|
||||
.unwrap())
|
||||
},
|
||||
|_, _| Ok(Some(rem)),
|
||||
|_, _| Ok(Some(dim)),
|
||||
)?
|
||||
.unwrap()
|
||||
.into_int_value();
|
||||
shape.push(dim);
|
||||
}
|
||||
create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice())
|
||||
}
|
||||
BasicValueEnum::IntValue(shape_int) => {
|
||||
// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)`
|
||||
let shape_int = gen_if_else_expr_callback(
|
||||
generator,
|
||||
ctx,
|
||||
|_, ctx| {
|
||||
Ok(ctx
|
||||
.builder
|
||||
.build_int_compare(
|
||||
IntPredicate::SLT,
|
||||
shape_int,
|
||||
llvm_usize.const_zero(),
|
||||
"",
|
||||
)
|
||||
.unwrap())
|
||||
},
|
||||
|_, _| Ok(Some(n_sz)),
|
||||
|_, ctx| {
|
||||
Ok(Some(ctx.builder.build_int_s_extend(shape_int, llvm_usize, "").unwrap()))
|
||||
},
|
||||
)?
|
||||
.unwrap()
|
||||
.into_int_value();
|
||||
create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int])
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
.unwrap();
|
||||
|
||||
// Only allow one dimension to be negative
|
||||
let num_negs = ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder
|
||||
.build_int_compare(IntPredicate::ULT, num_negs, llvm_usize.const_int(2, false), "")
|
||||
.unwrap(),
|
||||
"0:ValueError",
|
||||
"can only specify one unknown dimension",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
// The new shape must be compatible with the old shape
|
||||
let out_sz = call_ndarray_calc_size(generator, ctx, &out.dim_sizes(), (None, None));
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(),
|
||||
"0:ValueError",
|
||||
"cannot reshape array of size {} into provided shape of size {}",
|
||||
[Some(n_sz), Some(out_sz), None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
gen_for_callback_incrementing(
|
||||
generator,
|
||||
ctx,
|
||||
None,
|
||||
llvm_usize.const_zero(),
|
||||
(n_sz, false),
|
||||
|generator, ctx, _, idx| {
|
||||
let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
|
||||
unsafe { out.data().set_unchecked(ctx, generator, &idx, elem) };
|
||||
Ok(())
|
||||
},
|
||||
llvm_usize.const_int(1, false),
|
||||
)?;
|
||||
|
||||
Ok(out.as_base_value().into())
|
||||
} else {
|
||||
unreachable!(
|
||||
"{FN_NAME}() not supported for '{}'",
|
||||
format!("'{}'", ctx.unifier.stringify(x1_ty))
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -557,6 +557,10 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
| PrimDef::FunNpHypot
|
||||
| PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim),
|
||||
|
||||
PrimDef::FunNpTranspose | PrimDef::FunNpReshape => {
|
||||
self.build_np_sp_ndarray_function(prim)
|
||||
}
|
||||
|
||||
PrimDef::FunNpDot
|
||||
| PrimDef::FunNpLinalgMatmul
|
||||
| PrimDef::FunNpLinalgCholesky
|
||||
|
@ -1885,6 +1889,57 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Build np/sp functions that take as input `NDArray` only
|
||||
fn build_np_sp_ndarray_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpTranspose, PrimDef::FunNpReshape]);
|
||||
|
||||
match prim {
|
||||
PrimDef::FunNpTranspose => {
|
||||
let ndarray_ty = self.unifier.get_fresh_var_with_range(
|
||||
&[self.ndarray_num_ty],
|
||||
Some("T".into()),
|
||||
None,
|
||||
);
|
||||
create_fn_by_codegen(
|
||||
self.unifier,
|
||||
&into_var_map([ndarray_ty]),
|
||||
prim.name(),
|
||||
ndarray_ty.ty,
|
||||
&[(ndarray_ty.ty, "x")],
|
||||
Box::new(move |ctx, _, fun, args, generator| {
|
||||
let arg_ty = fun.0.args[0].ty;
|
||||
let arg_val =
|
||||
args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||
Ok(Some(ndarray_transpose(generator, ctx, (arg_ty, arg_val))?))
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
// NOTE: on `ndarray_factory_fn_shape_arg_tvar` and
|
||||
// the `param_ty` for `create_fn_by_codegen`.
|
||||
//
|
||||
// Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking
|
||||
// to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`],
|
||||
// and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`.
|
||||
PrimDef::FunNpReshape => create_fn_by_codegen(
|
||||
self.unifier,
|
||||
&VarMap::new(),
|
||||
prim.name(),
|
||||
self.ndarray_num_ty,
|
||||
&[(self.ndarray_num_ty, "x"), (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
|
||||
Box::new(move |ctx, _, fun, args, generator| {
|
||||
let x1_ty = fun.0.args[0].ty;
|
||||
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||
let x2_ty = fun.0.args[1].ty;
|
||||
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
||||
Ok(Some(ndarray_reshape(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
|
||||
}),
|
||||
),
|
||||
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build `np_linalg` and `sp_linalg` functions
|
||||
///
|
||||
/// The input to these functions must be floating point `NDArray`
|
||||
|
|
|
@ -99,6 +99,8 @@ pub enum PrimDef {
|
|||
FunNpLdExp,
|
||||
FunNpHypot,
|
||||
FunNpNextAfter,
|
||||
FunNpTranspose,
|
||||
FunNpReshape,
|
||||
|
||||
// Linalg functions
|
||||
FunNpDot,
|
||||
|
@ -282,6 +284,10 @@ impl PrimDef {
|
|||
PrimDef::FunNpLdExp => fun("np_ldexp", None),
|
||||
PrimDef::FunNpHypot => fun("np_hypot", None),
|
||||
PrimDef::FunNpNextAfter => fun("np_nextafter", None),
|
||||
PrimDef::FunNpTranspose => fun("np_transpose", None),
|
||||
PrimDef::FunNpReshape => fun("np_reshape", None),
|
||||
|
||||
// Linalg functions
|
||||
PrimDef::FunNpDot => fun("np_dot", None),
|
||||
PrimDef::FunNpLinalgMatmul => fun("np_linalg_matmul", None),
|
||||
PrimDef::FunNpLinalgCholesky => fun("np_linalg_cholesky", None),
|
||||
|
|
|
@ -5,7 +5,7 @@ expression: res_vec
|
|||
[
|
||||
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
||||
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(245)]\n}\n",
|
||||
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(246)]\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
||||
|
|
|
@ -7,7 +7,7 @@ expression: res_vec
|
|||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B[typevar234]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar234\"]\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B[typevar235]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar235\"]\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
|
||||
|
|
|
@ -5,8 +5,8 @@ expression: res_vec
|
|||
[
|
||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(247)]\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(252)]\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(248)]\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(253)]\n}\n",
|
||||
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
|
|
|
@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
|
|||
expression: res_vec
|
||||
---
|
||||
[
|
||||
"Class {\nname: \"A\",\nancestors: [\"A[typevar233, typevar234]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar233\", \"typevar234\"]\n}\n",
|
||||
"Class {\nname: \"A\",\nancestors: [\"A[typevar234, typevar235]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar234\", \"typevar235\"]\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
|
||||
|
|
|
@ -6,12 +6,12 @@ expression: res_vec
|
|||
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(253)]\n}\n",
|
||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(254)]\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(261)]\n}\n",
|
||||
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(262)]\n}\n",
|
||||
]
|
||||
|
|
|
@ -1389,7 +1389,45 @@ impl<'a> Inferencer<'a> {
|
|||
},
|
||||
}));
|
||||
}
|
||||
// 2-argument ndarray n-dimensional factory functions
|
||||
if id == &"np_reshape".into() && args.len() == 2 {
|
||||
let arg0 = self.fold_expr(args.remove(0))?;
|
||||
|
||||
let shape_expr = args.remove(0);
|
||||
let (ndims, shape) =
|
||||
self.fold_numpy_function_call_shape_argument(*id, 0, shape_expr)?; // Special handling for `shape`
|
||||
|
||||
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
|
||||
let (elem_ty, _) = unpack_ndarray_var_tys(self.unifier, arg0.custom.unwrap());
|
||||
let ret = make_ndarray_ty(self.unifier, self.primitives, Some(elem_ty), Some(ndims));
|
||||
|
||||
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![
|
||||
FuncArg { name: "x1".into(), ty: arg0.custom.unwrap(), default_value: None },
|
||||
FuncArg {
|
||||
name: "shape".into(),
|
||||
ty: shape.custom.unwrap(),
|
||||
default_value: None,
|
||||
},
|
||||
],
|
||||
ret,
|
||||
vars: VarMap::new(),
|
||||
}));
|
||||
|
||||
return Ok(Some(Located {
|
||||
location,
|
||||
custom: Some(ret),
|
||||
node: ExprKind::Call {
|
||||
func: Box::new(Located {
|
||||
custom: Some(custom),
|
||||
location: func.location,
|
||||
node: ExprKind::Name { id: *id, ctx: *ctx },
|
||||
}),
|
||||
args: vec![arg0, shape],
|
||||
keywords: vec![],
|
||||
},
|
||||
}));
|
||||
}
|
||||
// 2-argument ndarray n-dimensional creation functions
|
||||
if id == &"np_full".into() && args.len() == 2 {
|
||||
let ExprKind::List { elts, .. } = &args[0].node else {
|
||||
|
|
Loading…
Reference in New Issue