forked from M-Labs/nac3
core: add np.transpose and np.reshape functions
This commit is contained in:
parent
a3e6bb2292
commit
00236f48bc
|
@ -2026,3 +2026,394 @@ pub fn gen_ndarray_fill<'ctx>(
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Generates LLVM IR for `ndarray.transpose`.
|
||||||
|
pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
x1: (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
|
const FN_NAME: &str = "ndarray_transpose";
|
||||||
|
let (x1_ty, x1) = x1;
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||||
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
|
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||||
|
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
||||||
|
|
||||||
|
// Dimensions are reversed in the transposed array
|
||||||
|
let out = create_ndarray_dyn_shape(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
elem_ty,
|
||||||
|
&n1,
|
||||||
|
|_, ctx, n| Ok(n.load_ndims(ctx)),
|
||||||
|
|generator, ctx, n, idx| {
|
||||||
|
let new_idx = ctx.builder.build_int_sub(n.load_ndims(ctx), idx, "").unwrap();
|
||||||
|
let new_idx = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_sub(new_idx, new_idx.get_type().const_int(1, false), "")
|
||||||
|
.unwrap();
|
||||||
|
unsafe { Ok(n.dim_sizes().get_typed_unchecked(ctx, generator, &new_idx, None)) }
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
gen_for_callback_incrementing(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
None,
|
||||||
|
llvm_usize.const_zero(),
|
||||||
|
(n_sz, false),
|
||||||
|
|generator, ctx, _, idx| {
|
||||||
|
let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
|
||||||
|
|
||||||
|
let new_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||||
|
let rem_idx = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||||
|
ctx.builder.build_store(new_idx, llvm_usize.const_zero()).unwrap();
|
||||||
|
ctx.builder.build_store(rem_idx, idx).unwrap();
|
||||||
|
|
||||||
|
// Incrementally calculate the new index in the transposed array
|
||||||
|
// For each index, we first decompose it into the n-dims and use those to reconstruct the new index
|
||||||
|
// The formula used for indexing is:
|
||||||
|
// idx = dim_n * ( ... (dim2 * (dim0 * dim1) + dim1) + dim2 ... ) + dim_n
|
||||||
|
gen_for_callback_incrementing(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
None,
|
||||||
|
llvm_usize.const_zero(),
|
||||||
|
(n1.load_ndims(ctx), false),
|
||||||
|
|generator, ctx, _, ndim| {
|
||||||
|
let ndim_rev =
|
||||||
|
ctx.builder.build_int_sub(n1.load_ndims(ctx), ndim, "").unwrap();
|
||||||
|
let ndim_rev = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_sub(ndim_rev, llvm_usize.const_int(1, false), "")
|
||||||
|
.unwrap();
|
||||||
|
let dim = unsafe {
|
||||||
|
n1.dim_sizes().get_typed_unchecked(ctx, generator, &ndim_rev, None)
|
||||||
|
};
|
||||||
|
|
||||||
|
let rem_idx_val =
|
||||||
|
ctx.builder.build_load(rem_idx, "").unwrap().into_int_value();
|
||||||
|
let new_idx_val =
|
||||||
|
ctx.builder.build_load(new_idx, "").unwrap().into_int_value();
|
||||||
|
|
||||||
|
let add_component =
|
||||||
|
ctx.builder.build_int_unsigned_rem(rem_idx_val, dim, "").unwrap();
|
||||||
|
let rem_idx_val =
|
||||||
|
ctx.builder.build_int_unsigned_div(rem_idx_val, dim, "").unwrap();
|
||||||
|
|
||||||
|
let new_idx_val = ctx.builder.build_int_mul(new_idx_val, dim, "").unwrap();
|
||||||
|
let new_idx_val =
|
||||||
|
ctx.builder.build_int_add(new_idx_val, add_component, "").unwrap();
|
||||||
|
|
||||||
|
ctx.builder.build_store(rem_idx, rem_idx_val).unwrap();
|
||||||
|
ctx.builder.build_store(new_idx, new_idx_val).unwrap();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
llvm_usize.const_int(1, false),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let new_idx_val = ctx.builder.build_load(new_idx, "").unwrap().into_int_value();
|
||||||
|
unsafe { out.data().set_unchecked(ctx, generator, &new_idx_val, elem) };
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
llvm_usize.const_int(1, false),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(out.as_base_value().into())
|
||||||
|
} else {
|
||||||
|
unreachable!(
|
||||||
|
"{FN_NAME}() not supported for '{}'",
|
||||||
|
format!("'{}'", ctx.unifier.stringify(x1_ty))
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// LLVM-typed implementation for generating the implementation for `ndarray.reshape`.
|
||||||
|
///
|
||||||
|
/// * `x1` - `NDArray` to reshape.
|
||||||
|
/// * `shape` - The `shape` parameter used to construct the new `NDArray`.
|
||||||
|
/// Just like numpy, the `shape` argument can be:
|
||||||
|
/// 1. A list of `int32`; e.g., `np.reshape(arr, [600, -1, 3])`
|
||||||
|
/// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))`
|
||||||
|
/// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)`
|
||||||
|
/// Note that unlike other generating functions, one of the dimesions in the shape can be negative
|
||||||
|
pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
x1: (Type, BasicValueEnum<'ctx>),
|
||||||
|
shape: (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
|
const FN_NAME: &str = "ndarray_reshape";
|
||||||
|
let (x1_ty, x1) = x1;
|
||||||
|
let (_, shape) = shape;
|
||||||
|
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
if let BasicValueEnum::PointerValue(n1) = x1 {
|
||||||
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
|
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||||
|
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
||||||
|
|
||||||
|
let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||||
|
let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||||
|
ctx.builder.build_store(acc, llvm_usize.const_int(1, false)).unwrap();
|
||||||
|
ctx.builder.build_store(num_neg, llvm_usize.const_zero()).unwrap();
|
||||||
|
|
||||||
|
let out = match shape {
|
||||||
|
BasicValueEnum::PointerValue(shape_list_ptr)
|
||||||
|
if ListValue::is_instance(shape_list_ptr, llvm_usize).is_ok() =>
|
||||||
|
{
|
||||||
|
// 1. A list of ints; e.g., `np.reshape(arr, [int64(600), int64(800, -1])`
|
||||||
|
|
||||||
|
let shape_list = ListValue::from_ptr_val(shape_list_ptr, llvm_usize, None);
|
||||||
|
// Check for -1 in dimensions
|
||||||
|
gen_for_callback_incrementing(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
None,
|
||||||
|
llvm_usize.const_zero(),
|
||||||
|
(shape_list.load_size(ctx, None), false),
|
||||||
|
|generator, ctx, _, idx| {
|
||||||
|
let ele =
|
||||||
|
shape_list.data().get(ctx, generator, &idx, None).into_int_value();
|
||||||
|
let ele = ctx.builder.build_int_s_extend(ele, llvm_usize, "").unwrap();
|
||||||
|
|
||||||
|
gen_if_else_expr_callback(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
|_, ctx| {
|
||||||
|
Ok(ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(
|
||||||
|
IntPredicate::SLT,
|
||||||
|
ele,
|
||||||
|
llvm_usize.const_zero(),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap())
|
||||||
|
},
|
||||||
|
|_, ctx| -> Result<Option<IntValue>, String> {
|
||||||
|
let num_neg_value =
|
||||||
|
ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
|
||||||
|
let num_neg_value = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_add(
|
||||||
|
num_neg_value,
|
||||||
|
llvm_usize.const_int(1, false),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
ctx.builder.build_store(num_neg, num_neg_value).unwrap();
|
||||||
|
Ok(None)
|
||||||
|
},
|
||||||
|
|_, ctx| {
|
||||||
|
let acc_value =
|
||||||
|
ctx.builder.build_load(acc, "").unwrap().into_int_value();
|
||||||
|
let acc_value =
|
||||||
|
ctx.builder.build_int_mul(acc_value, ele, "").unwrap();
|
||||||
|
ctx.builder.build_store(acc, acc_value).unwrap();
|
||||||
|
Ok(None)
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
llvm_usize.const_int(1, false),
|
||||||
|
)?;
|
||||||
|
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
|
||||||
|
let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
|
||||||
|
// Generate the output shape by filling -1 with `rem`
|
||||||
|
create_ndarray_dyn_shape(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
elem_ty,
|
||||||
|
&shape_list,
|
||||||
|
|_, ctx, _| Ok(shape_list.load_size(ctx, None)),
|
||||||
|
|generator, ctx, shape_list, idx| {
|
||||||
|
let dim =
|
||||||
|
shape_list.data().get(ctx, generator, &idx, None).into_int_value();
|
||||||
|
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
|
||||||
|
|
||||||
|
Ok(gen_if_else_expr_callback(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
|_, ctx| {
|
||||||
|
Ok(ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(
|
||||||
|
IntPredicate::SLT,
|
||||||
|
dim,
|
||||||
|
llvm_usize.const_zero(),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap())
|
||||||
|
},
|
||||||
|
|_, _| Ok(Some(rem)),
|
||||||
|
|_, _| Ok(Some(dim)),
|
||||||
|
)?
|
||||||
|
.unwrap()
|
||||||
|
.into_int_value())
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
BasicValueEnum::StructValue(shape_tuple) => {
|
||||||
|
// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))`
|
||||||
|
|
||||||
|
let ndims = shape_tuple.get_type().count_fields();
|
||||||
|
// Check for -1 in dims
|
||||||
|
for dim_i in 0..ndims {
|
||||||
|
let dim = ctx
|
||||||
|
.builder
|
||||||
|
.build_extract_value(shape_tuple, dim_i, "")
|
||||||
|
.unwrap()
|
||||||
|
.into_int_value();
|
||||||
|
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
|
||||||
|
|
||||||
|
gen_if_else_expr_callback(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
|_, ctx| {
|
||||||
|
Ok(ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(
|
||||||
|
IntPredicate::SLT,
|
||||||
|
dim,
|
||||||
|
llvm_usize.const_zero(),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap())
|
||||||
|
},
|
||||||
|
|_, ctx| -> Result<Option<IntValue>, String> {
|
||||||
|
let num_negs =
|
||||||
|
ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
|
||||||
|
let num_negs = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_add(num_negs, llvm_usize.const_int(1, false), "")
|
||||||
|
.unwrap();
|
||||||
|
ctx.builder.build_store(num_neg, num_negs).unwrap();
|
||||||
|
Ok(None)
|
||||||
|
},
|
||||||
|
|_, ctx| {
|
||||||
|
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
|
||||||
|
let acc_val = ctx.builder.build_int_mul(acc_val, dim, "").unwrap();
|
||||||
|
ctx.builder.build_store(acc, acc_val).unwrap();
|
||||||
|
Ok(None)
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
|
||||||
|
let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
|
||||||
|
let mut shape = Vec::with_capacity(ndims as usize);
|
||||||
|
|
||||||
|
// Reconstruct shape filling negatives with rem
|
||||||
|
for dim_i in 0..ndims {
|
||||||
|
let dim = ctx
|
||||||
|
.builder
|
||||||
|
.build_extract_value(shape_tuple, dim_i, "")
|
||||||
|
.unwrap()
|
||||||
|
.into_int_value();
|
||||||
|
let dim = ctx.builder.build_int_s_extend(dim, llvm_usize, "").unwrap();
|
||||||
|
|
||||||
|
let dim = gen_if_else_expr_callback(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
|_, ctx| {
|
||||||
|
Ok(ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(
|
||||||
|
IntPredicate::SLT,
|
||||||
|
dim,
|
||||||
|
llvm_usize.const_zero(),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap())
|
||||||
|
},
|
||||||
|
|_, _| Ok(Some(rem)),
|
||||||
|
|_, _| Ok(Some(dim)),
|
||||||
|
)?
|
||||||
|
.unwrap()
|
||||||
|
.into_int_value();
|
||||||
|
shape.push(dim);
|
||||||
|
}
|
||||||
|
create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice())
|
||||||
|
}
|
||||||
|
BasicValueEnum::IntValue(shape_int) => {
|
||||||
|
// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)`
|
||||||
|
let shape_int = gen_if_else_expr_callback(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
|_, ctx| {
|
||||||
|
Ok(ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(
|
||||||
|
IntPredicate::SLT,
|
||||||
|
shape_int,
|
||||||
|
llvm_usize.const_zero(),
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.unwrap())
|
||||||
|
},
|
||||||
|
|_, _| Ok(Some(n_sz)),
|
||||||
|
|_, ctx| {
|
||||||
|
Ok(Some(ctx.builder.build_int_s_extend(shape_int, llvm_usize, "").unwrap()))
|
||||||
|
},
|
||||||
|
)?
|
||||||
|
.unwrap()
|
||||||
|
.into_int_value();
|
||||||
|
create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int])
|
||||||
|
}
|
||||||
|
_ => unreachable!(),
|
||||||
|
}
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Only allow one dimension to be negative
|
||||||
|
let num_negs = ctx.builder.build_load(num_neg, "").unwrap().into_int_value();
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
ctx.builder
|
||||||
|
.build_int_compare(IntPredicate::ULT, num_negs, llvm_usize.const_int(2, false), "")
|
||||||
|
.unwrap(),
|
||||||
|
"0:ValueError",
|
||||||
|
"can only specify one unknown dimension",
|
||||||
|
[None, None, None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
|
||||||
|
// The new shape must be compatible with the old shape
|
||||||
|
let out_sz = call_ndarray_calc_size(generator, ctx, &out.dim_sizes(), (None, None));
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(),
|
||||||
|
"0:ValueError",
|
||||||
|
"cannot reshape array of size {} into provided shape of size {}",
|
||||||
|
[Some(n_sz), Some(out_sz), None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
|
||||||
|
gen_for_callback_incrementing(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
None,
|
||||||
|
llvm_usize.const_zero(),
|
||||||
|
(n_sz, false),
|
||||||
|
|generator, ctx, _, idx| {
|
||||||
|
let elem = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
|
||||||
|
unsafe { out.data().set_unchecked(ctx, generator, &idx, elem) };
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
llvm_usize.const_int(1, false),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(out.as_base_value().into())
|
||||||
|
} else {
|
||||||
|
unreachable!(
|
||||||
|
"{FN_NAME}() not supported for '{}'",
|
||||||
|
format!("'{}'", ctx.unifier.stringify(x1_ty))
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -557,6 +557,10 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
| PrimDef::FunNpHypot
|
| PrimDef::FunNpHypot
|
||||||
| PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim),
|
| PrimDef::FunNpNextAfter => self.build_np_2ary_function(prim),
|
||||||
|
|
||||||
|
PrimDef::FunNpTranspose | PrimDef::FunNpReshape => {
|
||||||
|
self.build_np_sp_ndarray_function(prim)
|
||||||
|
}
|
||||||
|
|
||||||
PrimDef::FunNpDot
|
PrimDef::FunNpDot
|
||||||
| PrimDef::FunNpLinalgMatmul
|
| PrimDef::FunNpLinalgMatmul
|
||||||
| PrimDef::FunNpLinalgCholesky
|
| PrimDef::FunNpLinalgCholesky
|
||||||
|
@ -1885,6 +1889,57 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Build np/sp functions that take as input `NDArray` only
|
||||||
|
fn build_np_sp_ndarray_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||||
|
debug_assert_prim_is_allowed(prim, &[PrimDef::FunNpTranspose, PrimDef::FunNpReshape]);
|
||||||
|
|
||||||
|
match prim {
|
||||||
|
PrimDef::FunNpTranspose => {
|
||||||
|
let ndarray_ty = self.unifier.get_fresh_var_with_range(
|
||||||
|
&[self.ndarray_num_ty],
|
||||||
|
Some("T".into()),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
create_fn_by_codegen(
|
||||||
|
self.unifier,
|
||||||
|
&into_var_map([ndarray_ty]),
|
||||||
|
prim.name(),
|
||||||
|
ndarray_ty.ty,
|
||||||
|
&[(ndarray_ty.ty, "x")],
|
||||||
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
|
let arg_ty = fun.0.args[0].ty;
|
||||||
|
let arg_val =
|
||||||
|
args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||||
|
Ok(Some(ndarray_transpose(generator, ctx, (arg_ty, arg_val))?))
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE: on `ndarray_factory_fn_shape_arg_tvar` and
|
||||||
|
// the `param_ty` for `create_fn_by_codegen`.
|
||||||
|
//
|
||||||
|
// Similar to `build_ndarray_from_shape_factory_function` we delegate the responsibility of typechecking
|
||||||
|
// to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`],
|
||||||
|
// and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`.
|
||||||
|
PrimDef::FunNpReshape => create_fn_by_codegen(
|
||||||
|
self.unifier,
|
||||||
|
&VarMap::new(),
|
||||||
|
prim.name(),
|
||||||
|
self.ndarray_num_ty,
|
||||||
|
&[(self.ndarray_num_ty, "x"), (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
|
||||||
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
|
let x1_ty = fun.0.args[0].ty;
|
||||||
|
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||||
|
let x2_ty = fun.0.args[1].ty;
|
||||||
|
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
||||||
|
Ok(Some(ndarray_reshape(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
|
||||||
|
_ => unreachable!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Build `np_linalg` and `sp_linalg` functions
|
/// Build `np_linalg` and `sp_linalg` functions
|
||||||
///
|
///
|
||||||
/// The input to these functions must be floating point `NDArray`
|
/// The input to these functions must be floating point `NDArray`
|
||||||
|
|
|
@ -99,6 +99,8 @@ pub enum PrimDef {
|
||||||
FunNpLdExp,
|
FunNpLdExp,
|
||||||
FunNpHypot,
|
FunNpHypot,
|
||||||
FunNpNextAfter,
|
FunNpNextAfter,
|
||||||
|
FunNpTranspose,
|
||||||
|
FunNpReshape,
|
||||||
|
|
||||||
// Linalg functions
|
// Linalg functions
|
||||||
FunNpDot,
|
FunNpDot,
|
||||||
|
@ -282,6 +284,10 @@ impl PrimDef {
|
||||||
PrimDef::FunNpLdExp => fun("np_ldexp", None),
|
PrimDef::FunNpLdExp => fun("np_ldexp", None),
|
||||||
PrimDef::FunNpHypot => fun("np_hypot", None),
|
PrimDef::FunNpHypot => fun("np_hypot", None),
|
||||||
PrimDef::FunNpNextAfter => fun("np_nextafter", None),
|
PrimDef::FunNpNextAfter => fun("np_nextafter", None),
|
||||||
|
PrimDef::FunNpTranspose => fun("np_transpose", None),
|
||||||
|
PrimDef::FunNpReshape => fun("np_reshape", None),
|
||||||
|
|
||||||
|
// Linalg functions
|
||||||
PrimDef::FunNpDot => fun("np_dot", None),
|
PrimDef::FunNpDot => fun("np_dot", None),
|
||||||
PrimDef::FunNpLinalgMatmul => fun("np_linalg_matmul", None),
|
PrimDef::FunNpLinalgMatmul => fun("np_linalg_matmul", None),
|
||||||
PrimDef::FunNpLinalgCholesky => fun("np_linalg_cholesky", None),
|
PrimDef::FunNpLinalgCholesky => fun("np_linalg_cholesky", None),
|
||||||
|
|
|
@ -5,7 +5,7 @@ expression: res_vec
|
||||||
[
|
[
|
||||||
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
||||||
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(245)]\n}\n",
|
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(246)]\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
||||||
|
|
|
@ -7,7 +7,7 @@ expression: res_vec
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B[typevar234]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar234\"]\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B[typevar235]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar235\"]\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
|
||||||
|
|
|
@ -5,8 +5,8 @@ expression: res_vec
|
||||||
[
|
[
|
||||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(247)]\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(248)]\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(252)]\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(253)]\n}\n",
|
||||||
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
|
|
|
@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
|
||||||
expression: res_vec
|
expression: res_vec
|
||||||
---
|
---
|
||||||
[
|
[
|
||||||
"Class {\nname: \"A\",\nancestors: [\"A[typevar233, typevar234]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar233\", \"typevar234\"]\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A[typevar234, typevar235]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar234\", \"typevar235\"]\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
|
||||||
|
|
|
@ -6,12 +6,12 @@ expression: res_vec
|
||||||
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(253)]\n}\n",
|
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(254)]\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(261)]\n}\n",
|
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(262)]\n}\n",
|
||||||
]
|
]
|
||||||
|
|
|
@ -1389,7 +1389,45 @@ impl<'a> Inferencer<'a> {
|
||||||
},
|
},
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
// 2-argument ndarray n-dimensional factory functions
|
||||||
|
if id == &"np_reshape".into() && args.len() == 2 {
|
||||||
|
let arg0 = self.fold_expr(args.remove(0))?;
|
||||||
|
|
||||||
|
let shape_expr = args.remove(0);
|
||||||
|
let (ndims, shape) =
|
||||||
|
self.fold_numpy_function_call_shape_argument(*id, 0, shape_expr)?; // Special handling for `shape`
|
||||||
|
|
||||||
|
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
|
||||||
|
let (elem_ty, _) = unpack_ndarray_var_tys(self.unifier, arg0.custom.unwrap());
|
||||||
|
let ret = make_ndarray_ty(self.unifier, self.primitives, Some(elem_ty), Some(ndims));
|
||||||
|
|
||||||
|
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
|
args: vec![
|
||||||
|
FuncArg { name: "x1".into(), ty: arg0.custom.unwrap(), default_value: None },
|
||||||
|
FuncArg {
|
||||||
|
name: "shape".into(),
|
||||||
|
ty: shape.custom.unwrap(),
|
||||||
|
default_value: None,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
ret,
|
||||||
|
vars: VarMap::new(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
return Ok(Some(Located {
|
||||||
|
location,
|
||||||
|
custom: Some(ret),
|
||||||
|
node: ExprKind::Call {
|
||||||
|
func: Box::new(Located {
|
||||||
|
custom: Some(custom),
|
||||||
|
location: func.location,
|
||||||
|
node: ExprKind::Name { id: *id, ctx: *ctx },
|
||||||
|
}),
|
||||||
|
args: vec![arg0, shape],
|
||||||
|
keywords: vec![],
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
}
|
||||||
// 2-argument ndarray n-dimensional creation functions
|
// 2-argument ndarray n-dimensional creation functions
|
||||||
if id == &"np_full".into() && args.len() == 2 {
|
if id == &"np_full".into() && args.len() == 2 {
|
||||||
let ExprKind::List { elts, .. } = &args[0].node else {
|
let ExprKind::List { elts, .. } = &args[0].node else {
|
||||||
|
|
Loading…
Reference in New Issue