Reshape Completed

This commit is contained in:
= 2024-07-31 01:54:36 +08:00
parent 260a2fbb63
commit c71a567a51
3 changed files with 229 additions and 205 deletions

View File

@ -2252,7 +2252,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
llvm_usize.const_int(1, false),
)?;
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
let rem = ctx.builder.build_int_sub(n_sz, acc_val, "").unwrap();
let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
create_ndarray_dyn_shape(
generator,
ctx,
@ -2286,7 +2286,8 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
}
BasicValueEnum::StructValue(shape_tuple) => {
let ndims = shape_tuple.get_type().count_fields();
let mut acc_val = ndim_ty.const_int(1, false);
let acc = ctx.builder.build_alloca(ndim_ty, "").unwrap();
ctx.builder.build_store(acc, ndim_ty.const_int(1, false)).unwrap();
for dim_i in 0..ndims {
let dim = ctx
@ -2315,13 +2316,16 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
Ok(None)
},
|_, ctx| {
acc_val = ctx.builder.build_int_mul(acc_val, dim, "").unwrap();
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
let acc_val = ctx.builder.build_int_mul(acc_val, dim, "").unwrap();
ctx.builder.build_store(acc, acc_val).unwrap();
Ok(None)
},
)?;
}
let rem = ctx.builder.build_int_sub(n_sz, acc_val, "").unwrap();
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
let mut shape = Vec::with_capacity(ndims as usize);
for dim_i in 0..ndims {
@ -2397,8 +2401,8 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
generator,
ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(),
"0:ValueError",
"cannot reshape array of size {} into provided shape",
[Some(n_sz), None, None],
"cannot reshape array of size {} into provided shape of size {}",
[Some(n_sz), Some(out_sz), None],
ctx.current_loc,
);

View File

@ -1913,20 +1913,37 @@ impl<'a> BuiltinBuilder<'a> {
}),
),
PrimDef::FunNpReshape => create_fn_by_codegen(
self.unifier,
&var_map,
prim.name(),
ndarray_ty.ty,
&[(ndarray_ty.ty, "x"), (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x2_ty = fun.0.args[1].ty;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
Ok(Some(ndarray_reshape(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
}),
),
PrimDef::FunNpReshape => {
// Return type can have differend ndims
let ret_ty =
self.unifier.get_fresh_var_with_range(&[ndarray_type], Some("U".into()), None);
let var_map = var_map
.clone()
.into_iter()
.chain(once((ret_ty.id, ret_ty.ty)))
.chain(once((
self.ndarray_factory_fn_shape_arg_tvar.id,
self.ndarray_factory_fn_shape_arg_tvar.ty,
)))
.collect::<IndexMap<_, _>>();
create_fn_by_codegen(
self.unifier,
&var_map,
prim.name(),
ret_ty.ty,
&[(ndarray_ty.ty, "x"), (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val =
args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x2_ty = fun.0.args[1].ty;
let x2_val =
args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
Ok(Some(ndarray_reshape(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
}),
)
}
_ => unreachable!(),
}

View File

@ -1444,11 +1444,14 @@ def test_ndarray_transpose():
output_ndarray_float_1(y)
def test_ndarray_reshape():
x: ndarray[float, 1] = np_array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])
y = np_reshape(x, [2, 5])
w: ndarray[float, 1] = np_array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])
x: ndarray[float, 4] = np_reshape(w, (1, 2, 1, -1))
y: ndarray[float, 2] = np_reshape(x, [2, -1])
z: ndarray[float, 1] = np_reshape(w, 10)
output_ndarray_float_1(x)
# output_ndarray_float_1(y)
output_ndarray_float_1(w)
output_ndarray_float_2(y)
output_ndarray_float_1(z)
def test_ndarray_dot():
x: ndarray[float, 1] = np_array([5.0, 1.0])
@ -1549,194 +1552,194 @@ def test_ndarray_svd():
def run() -> int32:
# test_ndarray_ctor()
# test_ndarray_empty()
# test_ndarray_zeros()
# test_ndarray_ones()
# test_ndarray_full()
# test_ndarray_eye()
# test_ndarray_array()
# test_ndarray_identity()
# test_ndarray_fill()
# test_ndarray_copy()
test_ndarray_ctor()
test_ndarray_empty()
test_ndarray_zeros()
test_ndarray_ones()
test_ndarray_full()
test_ndarray_eye()
test_ndarray_array()
test_ndarray_identity()
test_ndarray_fill()
test_ndarray_copy()
# test_ndarray_neg_idx()
# test_ndarray_slices()
# test_ndarray_nd_idx()
test_ndarray_neg_idx()
test_ndarray_slices()
test_ndarray_nd_idx()
# test_ndarray_add()
# test_ndarray_add_broadcast()
# test_ndarray_add_broadcast_lhs_scalar()
# test_ndarray_add_broadcast_rhs_scalar()
# test_ndarray_iadd()
# test_ndarray_iadd_broadcast()
# test_ndarray_iadd_broadcast_scalar()
# test_ndarray_sub()
# test_ndarray_sub_broadcast()
# test_ndarray_sub_broadcast_lhs_scalar()
# test_ndarray_sub_broadcast_rhs_scalar()
# test_ndarray_isub()
# test_ndarray_isub_broadcast()
# test_ndarray_isub_broadcast_scalar()
# test_ndarray_mul()
# test_ndarray_mul_broadcast()
# test_ndarray_mul_broadcast_lhs_scalar()
# test_ndarray_mul_broadcast_rhs_scalar()
# test_ndarray_imul()
# test_ndarray_imul_broadcast()
# test_ndarray_imul_broadcast_scalar()
# test_ndarray_truediv()
# test_ndarray_truediv_broadcast()
# test_ndarray_truediv_broadcast_lhs_scalar()
# test_ndarray_truediv_broadcast_rhs_scalar()
# test_ndarray_itruediv()
# test_ndarray_itruediv_broadcast()
# test_ndarray_itruediv_broadcast_scalar()
# test_ndarray_floordiv()
# test_ndarray_floordiv_broadcast()
# test_ndarray_floordiv_broadcast_lhs_scalar()
# test_ndarray_floordiv_broadcast_rhs_scalar()
# test_ndarray_ifloordiv()
# test_ndarray_ifloordiv_broadcast()
# test_ndarray_ifloordiv_broadcast_scalar()
# test_ndarray_mod()
# test_ndarray_mod_broadcast()
# test_ndarray_mod_broadcast_lhs_scalar()
# test_ndarray_mod_broadcast_rhs_scalar()
# test_ndarray_imod()
# test_ndarray_imod_broadcast()
# test_ndarray_imod_broadcast_scalar()
# test_ndarray_pow()
# test_ndarray_pow_broadcast()
# test_ndarray_pow_broadcast_lhs_scalar()
# test_ndarray_pow_broadcast_rhs_scalar()
# test_ndarray_ipow()
# test_ndarray_ipow_broadcast()
# test_ndarray_ipow_broadcast_scalar()
# test_ndarray_matmul()
# test_ndarray_imatmul()
# test_ndarray_pos()
# test_ndarray_neg()
# test_ndarray_inv()
# test_ndarray_eq()
# test_ndarray_eq_broadcast()
# test_ndarray_eq_broadcast_lhs_scalar()
# test_ndarray_eq_broadcast_rhs_scalar()
# test_ndarray_ne()
# test_ndarray_ne_broadcast()
# test_ndarray_ne_broadcast_lhs_scalar()
# test_ndarray_ne_broadcast_rhs_scalar()
# test_ndarray_lt()
# test_ndarray_lt_broadcast()
# test_ndarray_lt_broadcast_lhs_scalar()
# test_ndarray_lt_broadcast_rhs_scalar()
# test_ndarray_lt()
# test_ndarray_le_broadcast()
# test_ndarray_le_broadcast_lhs_scalar()
# test_ndarray_le_broadcast_rhs_scalar()
# test_ndarray_gt()
# test_ndarray_gt_broadcast()
# test_ndarray_gt_broadcast_lhs_scalar()
# test_ndarray_gt_broadcast_rhs_scalar()
# test_ndarray_gt()
# test_ndarray_ge_broadcast()
# test_ndarray_ge_broadcast_lhs_scalar()
# test_ndarray_ge_broadcast_rhs_scalar()
test_ndarray_add()
test_ndarray_add_broadcast()
test_ndarray_add_broadcast_lhs_scalar()
test_ndarray_add_broadcast_rhs_scalar()
test_ndarray_iadd()
test_ndarray_iadd_broadcast()
test_ndarray_iadd_broadcast_scalar()
test_ndarray_sub()
test_ndarray_sub_broadcast()
test_ndarray_sub_broadcast_lhs_scalar()
test_ndarray_sub_broadcast_rhs_scalar()
test_ndarray_isub()
test_ndarray_isub_broadcast()
test_ndarray_isub_broadcast_scalar()
test_ndarray_mul()
test_ndarray_mul_broadcast()
test_ndarray_mul_broadcast_lhs_scalar()
test_ndarray_mul_broadcast_rhs_scalar()
test_ndarray_imul()
test_ndarray_imul_broadcast()
test_ndarray_imul_broadcast_scalar()
test_ndarray_truediv()
test_ndarray_truediv_broadcast()
test_ndarray_truediv_broadcast_lhs_scalar()
test_ndarray_truediv_broadcast_rhs_scalar()
test_ndarray_itruediv()
test_ndarray_itruediv_broadcast()
test_ndarray_itruediv_broadcast_scalar()
test_ndarray_floordiv()
test_ndarray_floordiv_broadcast()
test_ndarray_floordiv_broadcast_lhs_scalar()
test_ndarray_floordiv_broadcast_rhs_scalar()
test_ndarray_ifloordiv()
test_ndarray_ifloordiv_broadcast()
test_ndarray_ifloordiv_broadcast_scalar()
test_ndarray_mod()
test_ndarray_mod_broadcast()
test_ndarray_mod_broadcast_lhs_scalar()
test_ndarray_mod_broadcast_rhs_scalar()
test_ndarray_imod()
test_ndarray_imod_broadcast()
test_ndarray_imod_broadcast_scalar()
test_ndarray_pow()
test_ndarray_pow_broadcast()
test_ndarray_pow_broadcast_lhs_scalar()
test_ndarray_pow_broadcast_rhs_scalar()
test_ndarray_ipow()
test_ndarray_ipow_broadcast()
test_ndarray_ipow_broadcast_scalar()
test_ndarray_matmul()
test_ndarray_imatmul()
test_ndarray_pos()
test_ndarray_neg()
test_ndarray_inv()
test_ndarray_eq()
test_ndarray_eq_broadcast()
test_ndarray_eq_broadcast_lhs_scalar()
test_ndarray_eq_broadcast_rhs_scalar()
test_ndarray_ne()
test_ndarray_ne_broadcast()
test_ndarray_ne_broadcast_lhs_scalar()
test_ndarray_ne_broadcast_rhs_scalar()
test_ndarray_lt()
test_ndarray_lt_broadcast()
test_ndarray_lt_broadcast_lhs_scalar()
test_ndarray_lt_broadcast_rhs_scalar()
test_ndarray_lt()
test_ndarray_le_broadcast()
test_ndarray_le_broadcast_lhs_scalar()
test_ndarray_le_broadcast_rhs_scalar()
test_ndarray_gt()
test_ndarray_gt_broadcast()
test_ndarray_gt_broadcast_lhs_scalar()
test_ndarray_gt_broadcast_rhs_scalar()
test_ndarray_gt()
test_ndarray_ge_broadcast()
test_ndarray_ge_broadcast_lhs_scalar()
test_ndarray_ge_broadcast_rhs_scalar()
# test_ndarray_int32()
# test_ndarray_int64()
# test_ndarray_uint32()
# test_ndarray_uint64()
# test_ndarray_float()
# test_ndarray_bool()
test_ndarray_int32()
test_ndarray_int64()
test_ndarray_uint32()
test_ndarray_uint64()
test_ndarray_float()
test_ndarray_bool()
# test_ndarray_round()
# test_ndarray_floor()
# test_ndarray_min()
# test_ndarray_minimum()
# test_ndarray_minimum_broadcast()
# test_ndarray_minimum_broadcast_lhs_scalar()
# test_ndarray_minimum_broadcast_rhs_scalar()
# test_ndarray_argmin()
# test_ndarray_max()
# test_ndarray_maximum()
# test_ndarray_maximum_broadcast()
# test_ndarray_maximum_broadcast_lhs_scalar()
# test_ndarray_maximum_broadcast_rhs_scalar()
# test_ndarray_argmax()
# test_ndarray_abs()
# test_ndarray_isnan()
# test_ndarray_isinf()
test_ndarray_round()
test_ndarray_floor()
test_ndarray_min()
test_ndarray_minimum()
test_ndarray_minimum_broadcast()
test_ndarray_minimum_broadcast_lhs_scalar()
test_ndarray_minimum_broadcast_rhs_scalar()
test_ndarray_argmin()
test_ndarray_max()
test_ndarray_maximum()
test_ndarray_maximum_broadcast()
test_ndarray_maximum_broadcast_lhs_scalar()
test_ndarray_maximum_broadcast_rhs_scalar()
test_ndarray_argmax()
test_ndarray_abs()
test_ndarray_isnan()
test_ndarray_isinf()
# test_ndarray_sin()
# test_ndarray_cos()
# test_ndarray_exp()
# test_ndarray_exp2()
# test_ndarray_log()
# test_ndarray_log10()
# test_ndarray_log2()
# test_ndarray_fabs()
# test_ndarray_sqrt()
# test_ndarray_rint()
# test_ndarray_tan()
# test_ndarray_arcsin()
# test_ndarray_arccos()
# test_ndarray_arctan()
# test_ndarray_sinh()
# test_ndarray_cosh()
# test_ndarray_tanh()
# test_ndarray_arcsinh()
# test_ndarray_arccosh()
# test_ndarray_arctanh()
# test_ndarray_expm1()
# test_ndarray_cbrt()
test_ndarray_sin()
test_ndarray_cos()
test_ndarray_exp()
test_ndarray_exp2()
test_ndarray_log()
test_ndarray_log10()
test_ndarray_log2()
test_ndarray_fabs()
test_ndarray_sqrt()
test_ndarray_rint()
test_ndarray_tan()
test_ndarray_arcsin()
test_ndarray_arccos()
test_ndarray_arctan()
test_ndarray_sinh()
test_ndarray_cosh()
test_ndarray_tanh()
test_ndarray_arcsinh()
test_ndarray_arccosh()
test_ndarray_arctanh()
test_ndarray_expm1()
test_ndarray_cbrt()
# test_ndarray_erf()
# test_ndarray_erfc()
# test_ndarray_gamma()
# test_ndarray_gammaln()
# test_ndarray_j0()
# test_ndarray_j1()
test_ndarray_erf()
test_ndarray_erfc()
test_ndarray_gamma()
test_ndarray_gammaln()
test_ndarray_j0()
test_ndarray_j1()
# test_ndarray_arctan2()
# test_ndarray_arctan2_broadcast()
# test_ndarray_arctan2_broadcast_lhs_scalar()
# test_ndarray_arctan2_broadcast_rhs_scalar()
# test_ndarray_copysign()
# test_ndarray_copysign_broadcast()
# test_ndarray_copysign_broadcast_lhs_scalar()
# test_ndarray_copysign_broadcast_rhs_scalar()
# test_ndarray_fmax()
# test_ndarray_fmax_broadcast()
# test_ndarray_fmax_broadcast_lhs_scalar()
# test_ndarray_fmax_broadcast_rhs_scalar()
# test_ndarray_fmin()
# test_ndarray_fmin_broadcast()
# test_ndarray_fmin_broadcast_lhs_scalar()
# test_ndarray_fmin_broadcast_rhs_scalar()
# test_ndarray_ldexp()
# test_ndarray_ldexp_broadcast()
# test_ndarray_ldexp_broadcast_lhs_scalar()
# test_ndarray_ldexp_broadcast_rhs_scalar()
# test_ndarray_hypot()
# test_ndarray_hypot_broadcast()
# test_ndarray_hypot_broadcast_lhs_scalar()
# test_ndarray_hypot_broadcast_rhs_scalar()
# test_ndarray_nextafter()
# test_ndarray_nextafter_broadcast()
# test_ndarray_nextafter_broadcast_lhs_scalar()
# test_ndarray_nextafter_broadcast_rhs_scalar()
# test_ndarray_transpose()
test_ndarray_arctan2()
test_ndarray_arctan2_broadcast()
test_ndarray_arctan2_broadcast_lhs_scalar()
test_ndarray_arctan2_broadcast_rhs_scalar()
test_ndarray_copysign()
test_ndarray_copysign_broadcast()
test_ndarray_copysign_broadcast_lhs_scalar()
test_ndarray_copysign_broadcast_rhs_scalar()
test_ndarray_fmax()
test_ndarray_fmax_broadcast()
test_ndarray_fmax_broadcast_lhs_scalar()
test_ndarray_fmax_broadcast_rhs_scalar()
test_ndarray_fmin()
test_ndarray_fmin_broadcast()
test_ndarray_fmin_broadcast_lhs_scalar()
test_ndarray_fmin_broadcast_rhs_scalar()
test_ndarray_ldexp()
test_ndarray_ldexp_broadcast()
test_ndarray_ldexp_broadcast_lhs_scalar()
test_ndarray_ldexp_broadcast_rhs_scalar()
test_ndarray_hypot()
test_ndarray_hypot_broadcast()
test_ndarray_hypot_broadcast_lhs_scalar()
test_ndarray_hypot_broadcast_rhs_scalar()
test_ndarray_nextafter()
test_ndarray_nextafter_broadcast()
test_ndarray_nextafter_broadcast_lhs_scalar()
test_ndarray_nextafter_broadcast_rhs_scalar()
test_ndarray_transpose()
test_ndarray_reshape()
# test_ndarray_dot()
# test_ndarray_linalg_matmul()
# test_ndarray_cholesky()
# test_ndarray_qr()
# test_ndarray_svd()
# test_ndarray_linalg_inv()
# test_ndarray_pinv()
# test_ndarray_lu()
# test_ndarray_schur()
# test_ndarray_hessenberg()
test_ndarray_dot()
test_ndarray_linalg_matmul()
test_ndarray_cholesky()
test_ndarray_qr()
test_ndarray_svd()
test_ndarray_linalg_inv()
test_ndarray_pinv()
test_ndarray_lu()
test_ndarray_schur()
test_ndarray_hessenberg()
return 0