From c71a567a515b271de5150f310e39a5d125a8943f Mon Sep 17 00:00:00 2001 From: = <=> Date: Wed, 31 Jul 2024 01:54:36 +0800 Subject: [PATCH] Reshape Completed --- nac3core/src/codegen/numpy.rs | 16 +- nac3core/src/toplevel/builtins.rs | 45 ++-- nac3standalone/demo/src/ndarray.py | 373 +++++++++++++++-------------- 3 files changed, 229 insertions(+), 205 deletions(-) diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index ee19db1a..20d3ed8f 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -2252,7 +2252,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>( llvm_usize.const_int(1, false), )?; let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value(); - let rem = ctx.builder.build_int_sub(n_sz, acc_val, "").unwrap(); + let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap(); create_ndarray_dyn_shape( generator, ctx, @@ -2286,7 +2286,8 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>( } BasicValueEnum::StructValue(shape_tuple) => { let ndims = shape_tuple.get_type().count_fields(); - let mut acc_val = ndim_ty.const_int(1, false); + let acc = ctx.builder.build_alloca(ndim_ty, "").unwrap(); + ctx.builder.build_store(acc, ndim_ty.const_int(1, false)).unwrap(); for dim_i in 0..ndims { let dim = ctx @@ -2315,13 +2316,16 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>( Ok(None) }, |_, ctx| { - acc_val = ctx.builder.build_int_mul(acc_val, dim, "").unwrap(); + let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value(); + let acc_val = ctx.builder.build_int_mul(acc_val, dim, "").unwrap(); + ctx.builder.build_store(acc, acc_val).unwrap(); Ok(None) }, )?; } - let rem = ctx.builder.build_int_sub(n_sz, acc_val, "").unwrap(); + let acc_val = ctx.builder.build_load(acc, "").unwrap().into_int_value(); + let rem = ctx.builder.build_int_unsigned_div(n_sz, acc_val, "").unwrap(); let mut shape = Vec::with_capacity(ndims as usize); for dim_i in 0..ndims { @@ -2397,8 +2401,8 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>( generator, ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(), "0:ValueError", - "cannot reshape array of size {} into provided shape", - [Some(n_sz), None, None], + "cannot reshape array of size {} into provided shape of size {}", + [Some(n_sz), Some(out_sz), None], ctx.current_loc, ); diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 28197575..7ccdd1ef 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1913,20 +1913,37 @@ impl<'a> BuiltinBuilder<'a> { }), ), - PrimDef::FunNpReshape => create_fn_by_codegen( - self.unifier, - &var_map, - prim.name(), - ndarray_ty.ty, - &[(ndarray_ty.ty, "x"), (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")], - Box::new(move |ctx, _, fun, args, generator| { - let x1_ty = fun.0.args[0].ty; - let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; - let x2_ty = fun.0.args[1].ty; - let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - Ok(Some(ndarray_reshape(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) - }), - ), + PrimDef::FunNpReshape => { + // Return type can have differend ndims + let ret_ty = + self.unifier.get_fresh_var_with_range(&[ndarray_type], Some("U".into()), None); + let var_map = var_map + .clone() + .into_iter() + .chain(once((ret_ty.id, ret_ty.ty))) + .chain(once(( + self.ndarray_factory_fn_shape_arg_tvar.id, + self.ndarray_factory_fn_shape_arg_tvar.ty, + ))) + .collect::>(); + + create_fn_by_codegen( + self.unifier, + &var_map, + prim.name(), + ret_ty.ty, + &[(ndarray_ty.ty, "x"), (self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")], + Box::new(move |ctx, _, fun, args, generator| { + let x1_ty = fun.0.args[0].ty; + let x1_val = + args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_ty = fun.0.args[1].ty; + let x2_val = + args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; + Ok(Some(ndarray_reshape(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) + }), + ) + } _ => unreachable!(), } diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index fdd4e046..b9600474 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -1444,11 +1444,14 @@ def test_ndarray_transpose(): output_ndarray_float_1(y) def test_ndarray_reshape(): - x: ndarray[float, 1] = np_array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]) - y = np_reshape(x, [2, 5]) + w: ndarray[float, 1] = np_array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]) + x: ndarray[float, 4] = np_reshape(w, (1, 2, 1, -1)) + y: ndarray[float, 2] = np_reshape(x, [2, -1]) + z: ndarray[float, 1] = np_reshape(w, 10) - output_ndarray_float_1(x) - # output_ndarray_float_1(y) + output_ndarray_float_1(w) + output_ndarray_float_2(y) + output_ndarray_float_1(z) def test_ndarray_dot(): x: ndarray[float, 1] = np_array([5.0, 1.0]) @@ -1549,194 +1552,194 @@ def test_ndarray_svd(): def run() -> int32: - # test_ndarray_ctor() - # test_ndarray_empty() - # test_ndarray_zeros() - # test_ndarray_ones() - # test_ndarray_full() - # test_ndarray_eye() - # test_ndarray_array() - # test_ndarray_identity() - # test_ndarray_fill() - # test_ndarray_copy() + test_ndarray_ctor() + test_ndarray_empty() + test_ndarray_zeros() + test_ndarray_ones() + test_ndarray_full() + test_ndarray_eye() + test_ndarray_array() + test_ndarray_identity() + test_ndarray_fill() + test_ndarray_copy() - # test_ndarray_neg_idx() - # test_ndarray_slices() - # test_ndarray_nd_idx() + test_ndarray_neg_idx() + test_ndarray_slices() + test_ndarray_nd_idx() - # test_ndarray_add() - # test_ndarray_add_broadcast() - # test_ndarray_add_broadcast_lhs_scalar() - # test_ndarray_add_broadcast_rhs_scalar() - # test_ndarray_iadd() - # test_ndarray_iadd_broadcast() - # test_ndarray_iadd_broadcast_scalar() - # test_ndarray_sub() - # test_ndarray_sub_broadcast() - # test_ndarray_sub_broadcast_lhs_scalar() - # test_ndarray_sub_broadcast_rhs_scalar() - # test_ndarray_isub() - # test_ndarray_isub_broadcast() - # test_ndarray_isub_broadcast_scalar() - # test_ndarray_mul() - # test_ndarray_mul_broadcast() - # test_ndarray_mul_broadcast_lhs_scalar() - # test_ndarray_mul_broadcast_rhs_scalar() - # test_ndarray_imul() - # test_ndarray_imul_broadcast() - # test_ndarray_imul_broadcast_scalar() - # test_ndarray_truediv() - # test_ndarray_truediv_broadcast() - # test_ndarray_truediv_broadcast_lhs_scalar() - # test_ndarray_truediv_broadcast_rhs_scalar() - # test_ndarray_itruediv() - # test_ndarray_itruediv_broadcast() - # test_ndarray_itruediv_broadcast_scalar() - # test_ndarray_floordiv() - # test_ndarray_floordiv_broadcast() - # test_ndarray_floordiv_broadcast_lhs_scalar() - # test_ndarray_floordiv_broadcast_rhs_scalar() - # test_ndarray_ifloordiv() - # test_ndarray_ifloordiv_broadcast() - # test_ndarray_ifloordiv_broadcast_scalar() - # test_ndarray_mod() - # test_ndarray_mod_broadcast() - # test_ndarray_mod_broadcast_lhs_scalar() - # test_ndarray_mod_broadcast_rhs_scalar() - # test_ndarray_imod() - # test_ndarray_imod_broadcast() - # test_ndarray_imod_broadcast_scalar() - # test_ndarray_pow() - # test_ndarray_pow_broadcast() - # test_ndarray_pow_broadcast_lhs_scalar() - # test_ndarray_pow_broadcast_rhs_scalar() - # test_ndarray_ipow() - # test_ndarray_ipow_broadcast() - # test_ndarray_ipow_broadcast_scalar() - # test_ndarray_matmul() - # test_ndarray_imatmul() - # test_ndarray_pos() - # test_ndarray_neg() - # test_ndarray_inv() - # test_ndarray_eq() - # test_ndarray_eq_broadcast() - # test_ndarray_eq_broadcast_lhs_scalar() - # test_ndarray_eq_broadcast_rhs_scalar() - # test_ndarray_ne() - # test_ndarray_ne_broadcast() - # test_ndarray_ne_broadcast_lhs_scalar() - # test_ndarray_ne_broadcast_rhs_scalar() - # test_ndarray_lt() - # test_ndarray_lt_broadcast() - # test_ndarray_lt_broadcast_lhs_scalar() - # test_ndarray_lt_broadcast_rhs_scalar() - # test_ndarray_lt() - # test_ndarray_le_broadcast() - # test_ndarray_le_broadcast_lhs_scalar() - # test_ndarray_le_broadcast_rhs_scalar() - # test_ndarray_gt() - # test_ndarray_gt_broadcast() - # test_ndarray_gt_broadcast_lhs_scalar() - # test_ndarray_gt_broadcast_rhs_scalar() - # test_ndarray_gt() - # test_ndarray_ge_broadcast() - # test_ndarray_ge_broadcast_lhs_scalar() - # test_ndarray_ge_broadcast_rhs_scalar() + test_ndarray_add() + test_ndarray_add_broadcast() + test_ndarray_add_broadcast_lhs_scalar() + test_ndarray_add_broadcast_rhs_scalar() + test_ndarray_iadd() + test_ndarray_iadd_broadcast() + test_ndarray_iadd_broadcast_scalar() + test_ndarray_sub() + test_ndarray_sub_broadcast() + test_ndarray_sub_broadcast_lhs_scalar() + test_ndarray_sub_broadcast_rhs_scalar() + test_ndarray_isub() + test_ndarray_isub_broadcast() + test_ndarray_isub_broadcast_scalar() + test_ndarray_mul() + test_ndarray_mul_broadcast() + test_ndarray_mul_broadcast_lhs_scalar() + test_ndarray_mul_broadcast_rhs_scalar() + test_ndarray_imul() + test_ndarray_imul_broadcast() + test_ndarray_imul_broadcast_scalar() + test_ndarray_truediv() + test_ndarray_truediv_broadcast() + test_ndarray_truediv_broadcast_lhs_scalar() + test_ndarray_truediv_broadcast_rhs_scalar() + test_ndarray_itruediv() + test_ndarray_itruediv_broadcast() + test_ndarray_itruediv_broadcast_scalar() + test_ndarray_floordiv() + test_ndarray_floordiv_broadcast() + test_ndarray_floordiv_broadcast_lhs_scalar() + test_ndarray_floordiv_broadcast_rhs_scalar() + test_ndarray_ifloordiv() + test_ndarray_ifloordiv_broadcast() + test_ndarray_ifloordiv_broadcast_scalar() + test_ndarray_mod() + test_ndarray_mod_broadcast() + test_ndarray_mod_broadcast_lhs_scalar() + test_ndarray_mod_broadcast_rhs_scalar() + test_ndarray_imod() + test_ndarray_imod_broadcast() + test_ndarray_imod_broadcast_scalar() + test_ndarray_pow() + test_ndarray_pow_broadcast() + test_ndarray_pow_broadcast_lhs_scalar() + test_ndarray_pow_broadcast_rhs_scalar() + test_ndarray_ipow() + test_ndarray_ipow_broadcast() + test_ndarray_ipow_broadcast_scalar() + test_ndarray_matmul() + test_ndarray_imatmul() + test_ndarray_pos() + test_ndarray_neg() + test_ndarray_inv() + test_ndarray_eq() + test_ndarray_eq_broadcast() + test_ndarray_eq_broadcast_lhs_scalar() + test_ndarray_eq_broadcast_rhs_scalar() + test_ndarray_ne() + test_ndarray_ne_broadcast() + test_ndarray_ne_broadcast_lhs_scalar() + test_ndarray_ne_broadcast_rhs_scalar() + test_ndarray_lt() + test_ndarray_lt_broadcast() + test_ndarray_lt_broadcast_lhs_scalar() + test_ndarray_lt_broadcast_rhs_scalar() + test_ndarray_lt() + test_ndarray_le_broadcast() + test_ndarray_le_broadcast_lhs_scalar() + test_ndarray_le_broadcast_rhs_scalar() + test_ndarray_gt() + test_ndarray_gt_broadcast() + test_ndarray_gt_broadcast_lhs_scalar() + test_ndarray_gt_broadcast_rhs_scalar() + test_ndarray_gt() + test_ndarray_ge_broadcast() + test_ndarray_ge_broadcast_lhs_scalar() + test_ndarray_ge_broadcast_rhs_scalar() - # test_ndarray_int32() - # test_ndarray_int64() - # test_ndarray_uint32() - # test_ndarray_uint64() - # test_ndarray_float() - # test_ndarray_bool() + test_ndarray_int32() + test_ndarray_int64() + test_ndarray_uint32() + test_ndarray_uint64() + test_ndarray_float() + test_ndarray_bool() - # test_ndarray_round() - # test_ndarray_floor() - # test_ndarray_min() - # test_ndarray_minimum() - # test_ndarray_minimum_broadcast() - # test_ndarray_minimum_broadcast_lhs_scalar() - # test_ndarray_minimum_broadcast_rhs_scalar() - # test_ndarray_argmin() - # test_ndarray_max() - # test_ndarray_maximum() - # test_ndarray_maximum_broadcast() - # test_ndarray_maximum_broadcast_lhs_scalar() - # test_ndarray_maximum_broadcast_rhs_scalar() - # test_ndarray_argmax() - # test_ndarray_abs() - # test_ndarray_isnan() - # test_ndarray_isinf() + test_ndarray_round() + test_ndarray_floor() + test_ndarray_min() + test_ndarray_minimum() + test_ndarray_minimum_broadcast() + test_ndarray_minimum_broadcast_lhs_scalar() + test_ndarray_minimum_broadcast_rhs_scalar() + test_ndarray_argmin() + test_ndarray_max() + test_ndarray_maximum() + test_ndarray_maximum_broadcast() + test_ndarray_maximum_broadcast_lhs_scalar() + test_ndarray_maximum_broadcast_rhs_scalar() + test_ndarray_argmax() + test_ndarray_abs() + test_ndarray_isnan() + test_ndarray_isinf() - # test_ndarray_sin() - # test_ndarray_cos() - # test_ndarray_exp() - # test_ndarray_exp2() - # test_ndarray_log() - # test_ndarray_log10() - # test_ndarray_log2() - # test_ndarray_fabs() - # test_ndarray_sqrt() - # test_ndarray_rint() - # test_ndarray_tan() - # test_ndarray_arcsin() - # test_ndarray_arccos() - # test_ndarray_arctan() - # test_ndarray_sinh() - # test_ndarray_cosh() - # test_ndarray_tanh() - # test_ndarray_arcsinh() - # test_ndarray_arccosh() - # test_ndarray_arctanh() - # test_ndarray_expm1() - # test_ndarray_cbrt() + test_ndarray_sin() + test_ndarray_cos() + test_ndarray_exp() + test_ndarray_exp2() + test_ndarray_log() + test_ndarray_log10() + test_ndarray_log2() + test_ndarray_fabs() + test_ndarray_sqrt() + test_ndarray_rint() + test_ndarray_tan() + test_ndarray_arcsin() + test_ndarray_arccos() + test_ndarray_arctan() + test_ndarray_sinh() + test_ndarray_cosh() + test_ndarray_tanh() + test_ndarray_arcsinh() + test_ndarray_arccosh() + test_ndarray_arctanh() + test_ndarray_expm1() + test_ndarray_cbrt() - # test_ndarray_erf() - # test_ndarray_erfc() - # test_ndarray_gamma() - # test_ndarray_gammaln() - # test_ndarray_j0() - # test_ndarray_j1() + test_ndarray_erf() + test_ndarray_erfc() + test_ndarray_gamma() + test_ndarray_gammaln() + test_ndarray_j0() + test_ndarray_j1() - # test_ndarray_arctan2() - # test_ndarray_arctan2_broadcast() - # test_ndarray_arctan2_broadcast_lhs_scalar() - # test_ndarray_arctan2_broadcast_rhs_scalar() - # test_ndarray_copysign() - # test_ndarray_copysign_broadcast() - # test_ndarray_copysign_broadcast_lhs_scalar() - # test_ndarray_copysign_broadcast_rhs_scalar() - # test_ndarray_fmax() - # test_ndarray_fmax_broadcast() - # test_ndarray_fmax_broadcast_lhs_scalar() - # test_ndarray_fmax_broadcast_rhs_scalar() - # test_ndarray_fmin() - # test_ndarray_fmin_broadcast() - # test_ndarray_fmin_broadcast_lhs_scalar() - # test_ndarray_fmin_broadcast_rhs_scalar() - # test_ndarray_ldexp() - # test_ndarray_ldexp_broadcast() - # test_ndarray_ldexp_broadcast_lhs_scalar() - # test_ndarray_ldexp_broadcast_rhs_scalar() - # test_ndarray_hypot() - # test_ndarray_hypot_broadcast() - # test_ndarray_hypot_broadcast_lhs_scalar() - # test_ndarray_hypot_broadcast_rhs_scalar() - # test_ndarray_nextafter() - # test_ndarray_nextafter_broadcast() - # test_ndarray_nextafter_broadcast_lhs_scalar() - # test_ndarray_nextafter_broadcast_rhs_scalar() - # test_ndarray_transpose() + test_ndarray_arctan2() + test_ndarray_arctan2_broadcast() + test_ndarray_arctan2_broadcast_lhs_scalar() + test_ndarray_arctan2_broadcast_rhs_scalar() + test_ndarray_copysign() + test_ndarray_copysign_broadcast() + test_ndarray_copysign_broadcast_lhs_scalar() + test_ndarray_copysign_broadcast_rhs_scalar() + test_ndarray_fmax() + test_ndarray_fmax_broadcast() + test_ndarray_fmax_broadcast_lhs_scalar() + test_ndarray_fmax_broadcast_rhs_scalar() + test_ndarray_fmin() + test_ndarray_fmin_broadcast() + test_ndarray_fmin_broadcast_lhs_scalar() + test_ndarray_fmin_broadcast_rhs_scalar() + test_ndarray_ldexp() + test_ndarray_ldexp_broadcast() + test_ndarray_ldexp_broadcast_lhs_scalar() + test_ndarray_ldexp_broadcast_rhs_scalar() + test_ndarray_hypot() + test_ndarray_hypot_broadcast() + test_ndarray_hypot_broadcast_lhs_scalar() + test_ndarray_hypot_broadcast_rhs_scalar() + test_ndarray_nextafter() + test_ndarray_nextafter_broadcast() + test_ndarray_nextafter_broadcast_lhs_scalar() + test_ndarray_nextafter_broadcast_rhs_scalar() + test_ndarray_transpose() test_ndarray_reshape() - # test_ndarray_dot() - # test_ndarray_linalg_matmul() - # test_ndarray_cholesky() - # test_ndarray_qr() - # test_ndarray_svd() - # test_ndarray_linalg_inv() - # test_ndarray_pinv() - # test_ndarray_lu() - # test_ndarray_schur() - # test_ndarray_hessenberg() + test_ndarray_dot() + test_ndarray_linalg_matmul() + test_ndarray_cholesky() + test_ndarray_qr() + test_ndarray_svd() + test_ndarray_linalg_inv() + test_ndarray_pinv() + test_ndarray_lu() + test_ndarray_schur() + test_ndarray_hessenberg() return 0