Reshape Completed

This commit is contained in:
= 2024-07-31 01:54:36 +08:00
parent 260a2fbb63
commit c71a567a51
3 changed files with 229 additions and 205 deletions

View File

@ -2252,7 +2252,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
llvm_usize.const_int(1, false), llvm_usize.const_int(1, false),
)?; )?;
let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value(); let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
let rem = ctx.builder.build_int_sub(n_sz, acc_val, "").unwrap(); let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
create_ndarray_dyn_shape( create_ndarray_dyn_shape(
generator, generator,
ctx, ctx,
@ -2286,7 +2286,8 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
} }
BasicValueEnum::StructValue(shape_tuple) => { BasicValueEnum::StructValue(shape_tuple) => {
let ndims = shape_tuple.get_type().count_fields(); let ndims = shape_tuple.get_type().count_fields();
let mut acc_val = ndim_ty.const_int(1, false); let acc = ctx.builder.build_alloca(ndim_ty, "").unwrap();
ctx.builder.build_store(acc, ndim_ty.const_int(1, false)).unwrap();
for dim_i in 0..ndims { for dim_i in 0..ndims {
let dim = ctx let dim = ctx
@ -2315,13 +2316,16 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
Ok(None) Ok(None)
}, },
|_, ctx| { |_, ctx| {
acc_val = ctx.builder.build_int_mul(acc_val, dim, "").unwrap(); let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
let acc_val = ctx.builder.build_int_mul(acc_val, dim, "").unwrap();
ctx.builder.build_store(acc, acc_val).unwrap();
Ok(None) Ok(None)
}, },
)?; )?;
} }
let rem = ctx.builder.build_int_sub(n_sz, acc_val, "").unwrap(); let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value();
let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap();
let mut shape = Vec::with_capacity(ndims as usize); let mut shape = Vec::with_capacity(ndims as usize);
for dim_i in 0..ndims { for dim_i in 0..ndims {
@ -2397,8 +2401,8 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
generator, generator,
ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(), ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(),
"0:ValueError", "0:ValueError",
"cannot reshape array of size {} into provided shape", "cannot reshape array of size {} into provided shape of size {}",
[Some(n_sz), None, None], [Some(n_sz), Some(out_sz), None],
ctx.current_loc, ctx.current_loc,
); );

View File

@ -1913,20 +1913,37 @@ impl<'a> BuiltinBuilder<'a> {
}), }),
), ),
PrimDef::FunNpReshape => create_fn_by_codegen( PrimDef::FunNpReshape => {
self.unifier, // Return type can have differend ndims
&var_map, let ret_ty =
prim.name(), self.unifier.get_fresh_var_with_range(&[ndarray_type], Some("U".into()), None);
ndarray_ty.ty, let var_map = var_map
&[(ndarray_ty.ty, "x"), (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")], .clone()
Box::new(move |ctx, _, fun, args, generator| { .into_iter()
let x1_ty = fun.0.args[0].ty; .chain(once((ret_ty.id, ret_ty.ty)))
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; .chain(once((
let x2_ty = fun.0.args[1].ty; self.ndarray_factory_fn_shape_arg_tvar.id,
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; self.ndarray_factory_fn_shape_arg_tvar.ty,
Ok(Some(ndarray_reshape(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) )))
}), .collect::<IndexMap<_, _>>();
),
create_fn_by_codegen(
self.unifier,
&var_map,
prim.name(),
ret_ty.ty,
&[(ndarray_ty.ty, "x"), (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
Box::new(move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val =
args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x2_ty = fun.0.args[1].ty;
let x2_val =
args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
Ok(Some(ndarray_reshape(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
}),
)
}
_ => unreachable!(), _ => unreachable!(),
} }

View File

@ -1444,11 +1444,14 @@ def test_ndarray_transpose():
output_ndarray_float_1(y) output_ndarray_float_1(y)
def test_ndarray_reshape(): def test_ndarray_reshape():
x: ndarray[float, 1] = np_array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]) w: ndarray[float, 1] = np_array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])
y = np_reshape(x, [2, 5]) x: ndarray[float, 4] = np_reshape(w, (1, 2, 1, -1))
y: ndarray[float, 2] = np_reshape(x, [2, -1])
z: ndarray[float, 1] = np_reshape(w, 10)
output_ndarray_float_1(x) output_ndarray_float_1(w)
# output_ndarray_float_1(y) output_ndarray_float_2(y)
output_ndarray_float_1(z)
def test_ndarray_dot(): def test_ndarray_dot():
x: ndarray[float, 1] = np_array([5.0, 1.0]) x: ndarray[float, 1] = np_array([5.0, 1.0])
@ -1549,194 +1552,194 @@ def test_ndarray_svd():
def run() -> int32: def run() -> int32:
# test_ndarray_ctor() test_ndarray_ctor()
# test_ndarray_empty() test_ndarray_empty()
# test_ndarray_zeros() test_ndarray_zeros()
# test_ndarray_ones() test_ndarray_ones()
# test_ndarray_full() test_ndarray_full()
# test_ndarray_eye() test_ndarray_eye()
# test_ndarray_array() test_ndarray_array()
# test_ndarray_identity() test_ndarray_identity()
# test_ndarray_fill() test_ndarray_fill()
# test_ndarray_copy() test_ndarray_copy()
# test_ndarray_neg_idx() test_ndarray_neg_idx()
# test_ndarray_slices() test_ndarray_slices()
# test_ndarray_nd_idx() test_ndarray_nd_idx()
# test_ndarray_add() test_ndarray_add()
# test_ndarray_add_broadcast() test_ndarray_add_broadcast()
# test_ndarray_add_broadcast_lhs_scalar() test_ndarray_add_broadcast_lhs_scalar()
# test_ndarray_add_broadcast_rhs_scalar() test_ndarray_add_broadcast_rhs_scalar()
# test_ndarray_iadd() test_ndarray_iadd()
# test_ndarray_iadd_broadcast() test_ndarray_iadd_broadcast()
# test_ndarray_iadd_broadcast_scalar() test_ndarray_iadd_broadcast_scalar()
# test_ndarray_sub() test_ndarray_sub()
# test_ndarray_sub_broadcast() test_ndarray_sub_broadcast()
# test_ndarray_sub_broadcast_lhs_scalar() test_ndarray_sub_broadcast_lhs_scalar()
# test_ndarray_sub_broadcast_rhs_scalar() test_ndarray_sub_broadcast_rhs_scalar()
# test_ndarray_isub() test_ndarray_isub()
# test_ndarray_isub_broadcast() test_ndarray_isub_broadcast()
# test_ndarray_isub_broadcast_scalar() test_ndarray_isub_broadcast_scalar()
# test_ndarray_mul() test_ndarray_mul()
# test_ndarray_mul_broadcast() test_ndarray_mul_broadcast()
# test_ndarray_mul_broadcast_lhs_scalar() test_ndarray_mul_broadcast_lhs_scalar()
# test_ndarray_mul_broadcast_rhs_scalar() test_ndarray_mul_broadcast_rhs_scalar()
# test_ndarray_imul() test_ndarray_imul()
# test_ndarray_imul_broadcast() test_ndarray_imul_broadcast()
# test_ndarray_imul_broadcast_scalar() test_ndarray_imul_broadcast_scalar()
# test_ndarray_truediv() test_ndarray_truediv()
# test_ndarray_truediv_broadcast() test_ndarray_truediv_broadcast()
# test_ndarray_truediv_broadcast_lhs_scalar() test_ndarray_truediv_broadcast_lhs_scalar()
# test_ndarray_truediv_broadcast_rhs_scalar() test_ndarray_truediv_broadcast_rhs_scalar()
# test_ndarray_itruediv() test_ndarray_itruediv()
# test_ndarray_itruediv_broadcast() test_ndarray_itruediv_broadcast()
# test_ndarray_itruediv_broadcast_scalar() test_ndarray_itruediv_broadcast_scalar()
# test_ndarray_floordiv() test_ndarray_floordiv()
# test_ndarray_floordiv_broadcast() test_ndarray_floordiv_broadcast()
# test_ndarray_floordiv_broadcast_lhs_scalar() test_ndarray_floordiv_broadcast_lhs_scalar()
# test_ndarray_floordiv_broadcast_rhs_scalar() test_ndarray_floordiv_broadcast_rhs_scalar()
# test_ndarray_ifloordiv() test_ndarray_ifloordiv()
# test_ndarray_ifloordiv_broadcast() test_ndarray_ifloordiv_broadcast()
# test_ndarray_ifloordiv_broadcast_scalar() test_ndarray_ifloordiv_broadcast_scalar()
# test_ndarray_mod() test_ndarray_mod()
# test_ndarray_mod_broadcast() test_ndarray_mod_broadcast()
# test_ndarray_mod_broadcast_lhs_scalar() test_ndarray_mod_broadcast_lhs_scalar()
# test_ndarray_mod_broadcast_rhs_scalar() test_ndarray_mod_broadcast_rhs_scalar()
# test_ndarray_imod() test_ndarray_imod()
# test_ndarray_imod_broadcast() test_ndarray_imod_broadcast()
# test_ndarray_imod_broadcast_scalar() test_ndarray_imod_broadcast_scalar()
# test_ndarray_pow() test_ndarray_pow()
# test_ndarray_pow_broadcast() test_ndarray_pow_broadcast()
# test_ndarray_pow_broadcast_lhs_scalar() test_ndarray_pow_broadcast_lhs_scalar()
# test_ndarray_pow_broadcast_rhs_scalar() test_ndarray_pow_broadcast_rhs_scalar()
# test_ndarray_ipow() test_ndarray_ipow()
# test_ndarray_ipow_broadcast() test_ndarray_ipow_broadcast()
# test_ndarray_ipow_broadcast_scalar() test_ndarray_ipow_broadcast_scalar()
# test_ndarray_matmul() test_ndarray_matmul()
# test_ndarray_imatmul() test_ndarray_imatmul()
# test_ndarray_pos() test_ndarray_pos()
# test_ndarray_neg() test_ndarray_neg()
# test_ndarray_inv() test_ndarray_inv()
# test_ndarray_eq() test_ndarray_eq()
# test_ndarray_eq_broadcast() test_ndarray_eq_broadcast()
# test_ndarray_eq_broadcast_lhs_scalar() test_ndarray_eq_broadcast_lhs_scalar()
# test_ndarray_eq_broadcast_rhs_scalar() test_ndarray_eq_broadcast_rhs_scalar()
# test_ndarray_ne() test_ndarray_ne()
# test_ndarray_ne_broadcast() test_ndarray_ne_broadcast()
# test_ndarray_ne_broadcast_lhs_scalar() test_ndarray_ne_broadcast_lhs_scalar()
# test_ndarray_ne_broadcast_rhs_scalar() test_ndarray_ne_broadcast_rhs_scalar()
# test_ndarray_lt() test_ndarray_lt()
# test_ndarray_lt_broadcast() test_ndarray_lt_broadcast()
# test_ndarray_lt_broadcast_lhs_scalar() test_ndarray_lt_broadcast_lhs_scalar()
# test_ndarray_lt_broadcast_rhs_scalar() test_ndarray_lt_broadcast_rhs_scalar()
# test_ndarray_lt() test_ndarray_lt()
# test_ndarray_le_broadcast() test_ndarray_le_broadcast()
# test_ndarray_le_broadcast_lhs_scalar() test_ndarray_le_broadcast_lhs_scalar()
# test_ndarray_le_broadcast_rhs_scalar() test_ndarray_le_broadcast_rhs_scalar()
# test_ndarray_gt() test_ndarray_gt()
# test_ndarray_gt_broadcast() test_ndarray_gt_broadcast()
# test_ndarray_gt_broadcast_lhs_scalar() test_ndarray_gt_broadcast_lhs_scalar()
# test_ndarray_gt_broadcast_rhs_scalar() test_ndarray_gt_broadcast_rhs_scalar()
# test_ndarray_gt() test_ndarray_gt()
# test_ndarray_ge_broadcast() test_ndarray_ge_broadcast()
# test_ndarray_ge_broadcast_lhs_scalar() test_ndarray_ge_broadcast_lhs_scalar()
# test_ndarray_ge_broadcast_rhs_scalar() test_ndarray_ge_broadcast_rhs_scalar()
# test_ndarray_int32() test_ndarray_int32()
# test_ndarray_int64() test_ndarray_int64()
# test_ndarray_uint32() test_ndarray_uint32()
# test_ndarray_uint64() test_ndarray_uint64()
# test_ndarray_float() test_ndarray_float()
# test_ndarray_bool() test_ndarray_bool()
# test_ndarray_round() test_ndarray_round()
# test_ndarray_floor() test_ndarray_floor()
# test_ndarray_min() test_ndarray_min()
# test_ndarray_minimum() test_ndarray_minimum()
# test_ndarray_minimum_broadcast() test_ndarray_minimum_broadcast()
# test_ndarray_minimum_broadcast_lhs_scalar() test_ndarray_minimum_broadcast_lhs_scalar()
# test_ndarray_minimum_broadcast_rhs_scalar() test_ndarray_minimum_broadcast_rhs_scalar()
# test_ndarray_argmin() test_ndarray_argmin()
# test_ndarray_max() test_ndarray_max()
# test_ndarray_maximum() test_ndarray_maximum()
# test_ndarray_maximum_broadcast() test_ndarray_maximum_broadcast()
# test_ndarray_maximum_broadcast_lhs_scalar() test_ndarray_maximum_broadcast_lhs_scalar()
# test_ndarray_maximum_broadcast_rhs_scalar() test_ndarray_maximum_broadcast_rhs_scalar()
# test_ndarray_argmax() test_ndarray_argmax()
# test_ndarray_abs() test_ndarray_abs()
# test_ndarray_isnan() test_ndarray_isnan()
# test_ndarray_isinf() test_ndarray_isinf()
# test_ndarray_sin() test_ndarray_sin()
# test_ndarray_cos() test_ndarray_cos()
# test_ndarray_exp() test_ndarray_exp()
# test_ndarray_exp2() test_ndarray_exp2()
# test_ndarray_log() test_ndarray_log()
# test_ndarray_log10() test_ndarray_log10()
# test_ndarray_log2() test_ndarray_log2()
# test_ndarray_fabs() test_ndarray_fabs()
# test_ndarray_sqrt() test_ndarray_sqrt()
# test_ndarray_rint() test_ndarray_rint()
# test_ndarray_tan() test_ndarray_tan()
# test_ndarray_arcsin() test_ndarray_arcsin()
# test_ndarray_arccos() test_ndarray_arccos()
# test_ndarray_arctan() test_ndarray_arctan()
# test_ndarray_sinh() test_ndarray_sinh()
# test_ndarray_cosh() test_ndarray_cosh()
# test_ndarray_tanh() test_ndarray_tanh()
# test_ndarray_arcsinh() test_ndarray_arcsinh()
# test_ndarray_arccosh() test_ndarray_arccosh()
# test_ndarray_arctanh() test_ndarray_arctanh()
# test_ndarray_expm1() test_ndarray_expm1()
# test_ndarray_cbrt() test_ndarray_cbrt()
# test_ndarray_erf() test_ndarray_erf()
# test_ndarray_erfc() test_ndarray_erfc()
# test_ndarray_gamma() test_ndarray_gamma()
# test_ndarray_gammaln() test_ndarray_gammaln()
# test_ndarray_j0() test_ndarray_j0()
# test_ndarray_j1() test_ndarray_j1()
# test_ndarray_arctan2() test_ndarray_arctan2()
# test_ndarray_arctan2_broadcast() test_ndarray_arctan2_broadcast()
# test_ndarray_arctan2_broadcast_lhs_scalar() test_ndarray_arctan2_broadcast_lhs_scalar()
# test_ndarray_arctan2_broadcast_rhs_scalar() test_ndarray_arctan2_broadcast_rhs_scalar()
# test_ndarray_copysign() test_ndarray_copysign()
# test_ndarray_copysign_broadcast() test_ndarray_copysign_broadcast()
# test_ndarray_copysign_broadcast_lhs_scalar() test_ndarray_copysign_broadcast_lhs_scalar()
# test_ndarray_copysign_broadcast_rhs_scalar() test_ndarray_copysign_broadcast_rhs_scalar()
# test_ndarray_fmax() test_ndarray_fmax()
# test_ndarray_fmax_broadcast() test_ndarray_fmax_broadcast()
# test_ndarray_fmax_broadcast_lhs_scalar() test_ndarray_fmax_broadcast_lhs_scalar()
# test_ndarray_fmax_broadcast_rhs_scalar() test_ndarray_fmax_broadcast_rhs_scalar()
# test_ndarray_fmin() test_ndarray_fmin()
# test_ndarray_fmin_broadcast() test_ndarray_fmin_broadcast()
# test_ndarray_fmin_broadcast_lhs_scalar() test_ndarray_fmin_broadcast_lhs_scalar()
# test_ndarray_fmin_broadcast_rhs_scalar() test_ndarray_fmin_broadcast_rhs_scalar()
# test_ndarray_ldexp() test_ndarray_ldexp()
# test_ndarray_ldexp_broadcast() test_ndarray_ldexp_broadcast()
# test_ndarray_ldexp_broadcast_lhs_scalar() test_ndarray_ldexp_broadcast_lhs_scalar()
# test_ndarray_ldexp_broadcast_rhs_scalar() test_ndarray_ldexp_broadcast_rhs_scalar()
# test_ndarray_hypot() test_ndarray_hypot()
# test_ndarray_hypot_broadcast() test_ndarray_hypot_broadcast()
# test_ndarray_hypot_broadcast_lhs_scalar() test_ndarray_hypot_broadcast_lhs_scalar()
# test_ndarray_hypot_broadcast_rhs_scalar() test_ndarray_hypot_broadcast_rhs_scalar()
# test_ndarray_nextafter() test_ndarray_nextafter()
# test_ndarray_nextafter_broadcast() test_ndarray_nextafter_broadcast()
# test_ndarray_nextafter_broadcast_lhs_scalar() test_ndarray_nextafter_broadcast_lhs_scalar()
# test_ndarray_nextafter_broadcast_rhs_scalar() test_ndarray_nextafter_broadcast_rhs_scalar()
# test_ndarray_transpose() test_ndarray_transpose()
test_ndarray_reshape() test_ndarray_reshape()
# test_ndarray_dot() test_ndarray_dot()
# test_ndarray_linalg_matmul() test_ndarray_linalg_matmul()
# test_ndarray_cholesky() test_ndarray_cholesky()
# test_ndarray_qr() test_ndarray_qr()
# test_ndarray_svd() test_ndarray_svd()
# test_ndarray_linalg_inv() test_ndarray_linalg_inv()
# test_ndarray_pinv() test_ndarray_pinv()
# test_ndarray_lu() test_ndarray_lu()
# test_ndarray_schur() test_ndarray_schur()
# test_ndarray_hessenberg() test_ndarray_hessenberg()
return 0 return 0