From 1c48d54afa85cfd46ef05dbdb040cfd6880aeacc Mon Sep 17 00:00:00 2001 From: lyken Date: Thu, 15 Aug 2024 15:13:02 +0800 Subject: [PATCH] WIP: core/ndstrides: fix nditer --- nac3core/irrt/irrt/ndarray/iter.hpp | 2 +- .../src/codegen/object/ndarray/factory.rs | 25 +++++++++++-------- nac3core/src/codegen/object/ndarray/mod.rs | 7 ++++-- nac3core/src/toplevel/builtins.rs | 22 ++++++++++------ 4 files changed, 34 insertions(+), 22 deletions(-) diff --git a/nac3core/irrt/irrt/ndarray/iter.hpp b/nac3core/irrt/irrt/ndarray/iter.hpp index 078646f9..0d59ba83 100644 --- a/nac3core/irrt/irrt/ndarray/iter.hpp +++ b/nac3core/irrt/irrt/ndarray/iter.hpp @@ -86,7 +86,7 @@ struct NDIter { indices[axis] = 0; // TODO: Can be optimized with backstrides. - element -= strides[axis] * shape[axis]; + element -= strides[axis] * (shape[axis] - 1); } else { element += strides[axis]; break; diff --git a/nac3core/src/codegen/object/ndarray/factory.rs b/nac3core/src/codegen/object/ndarray/factory.rs index 68294340..a191202d 100644 --- a/nac3core/src/codegen/object/ndarray/factory.rs +++ b/nac3core/src/codegen/object/ndarray/factory.rs @@ -32,7 +32,7 @@ fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>( } else if ctx.unifier.unioned(dtype, ctx.primitives.str) { ctx.gen_string(generator, "").value.into() } else { - unreachable!() + panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype)); } } @@ -40,28 +40,28 @@ fn ndarray_zero_value<'ctx, G: CodeGenerator + ?Sized>( fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, + dtype: Type, ) -> BasicValueEnum<'ctx> { if [ctx.primitives.int32, ctx.primitives.uint32] .iter() - .any(|ty| ctx.unifier.unioned(elem_ty, *ty)) + .any(|ty| ctx.unifier.unioned(dtype, *ty)) { - let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int32); + let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int32); ctx.ctx.i32_type().const_int(1, is_signed).into() } else if [ctx.primitives.int64, ctx.primitives.uint64] .iter() - .any(|ty| ctx.unifier.unioned(elem_ty, *ty)) + .any(|ty| ctx.unifier.unioned(dtype, *ty)) { - let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int64); + let is_signed = ctx.unifier.unioned(dtype, ctx.primitives.int64); ctx.ctx.i64_type().const_int(1, is_signed).into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) { + } else if ctx.unifier.unioned(dtype, ctx.primitives.float) { ctx.ctx.f64_type().const_float(1.0).into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) { + } else if ctx.unifier.unioned(dtype, ctx.primitives.bool) { ctx.ctx.bool_type().const_int(1, false).into() - } else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) { + } else if ctx.unifier.unioned(dtype, ctx.primitives.str) { ctx.gen_string(generator, "1").value.into() } else { - unreachable!() + panic!("unrecognized dtype: {}", ctx.unifier.stringify(dtype)); } } @@ -192,6 +192,8 @@ impl<'ctx> NDArrayObject<'ctx> { "eye_ndarray", ); + // Create data and make the matrix like look np.eye() + ndarray.create_data(generator, ctx); ndarray .foreach(generator, ctx, |generator, ctx, _hooks, nditer| { // NOTE: rows and cols can never be zero here, since this ndarray's `np.size` would be zero @@ -218,7 +220,8 @@ impl<'ctx> NDArrayObject<'ctx> { Ok(()) }) .unwrap(); - todo!() + + ndarray } /// Create an ndarray like `np.identity`. diff --git a/nac3core/src/codegen/object/ndarray/mod.rs b/nac3core/src/codegen/object/ndarray/mod.rs index 19fc055e..7bb5e331 100644 --- a/nac3core/src/codegen/object/ndarray/mod.rs +++ b/nac3core/src/codegen/object/ndarray/mod.rs @@ -649,9 +649,12 @@ impl<'ctx> NDArrayObject<'ctx> { NDArrayObject::alloca(generator, ctx, self.dtype, new_ndims, "reshaped_ndarray"); dst_ndarray.copy_shape_from_array(generator, ctx, new_shape); + // Reolsve negative indices let size = self.size(generator, ctx); - let new_ndims = dst_ndarray.get_ndims(generator, ctx.ctx); - call_nac3_ndarray_resolve_and_check_new_shape(generator, ctx, size, new_ndims, new_shape); + let dst_ndims = dst_ndarray.get_ndims(generator, ctx.ctx); + let dst_shape = + dst_ndarray.instance.get(generator, ctx, |f| f.shape, "reshaped_ndarray_shape"); + call_nac3_ndarray_resolve_and_check_new_shape(generator, ctx, size, dst_ndims, dst_shape); let is_c_contiguous = self.is_c_contiguous(generator, ctx); ctx.builder.build_conditional_branch(is_c_contiguous.value, then_bb, else_bb).unwrap(); diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 33e7cef6..3b3e49d6 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1556,15 +1556,21 @@ impl<'a> BuiltinBuilder<'a> { }?; // Implementation - let sizet_model = IntModel(SizeT); - let nrows = - sizet_model.check_value(generator, ctx.ctx, nrows_arg).unwrap(); - let ncols = - sizet_model.check_value(generator, ctx.ctx, ncols_arg).unwrap(); - let offset = - sizet_model.check_value(generator, ctx.ctx, offset_arg).unwrap(); + let i32_model = IntModel(Int32); + let nrows = i32_model + .check_value(generator, ctx.ctx, nrows_arg) + .unwrap() + .s_extend_or_bit_cast(generator, ctx, SizeT, "nrows"); + let ncols = i32_model + .check_value(generator, ctx.ctx, ncols_arg) + .unwrap() + .s_extend_or_bit_cast(generator, ctx, SizeT, "ncols"); + let offset = i32_model + .check_value(generator, ctx.ctx, offset_arg) + .unwrap() + .s_extend_or_bit_cast(generator, ctx, SizeT, "offset"); - let (_, dtype) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret); + let (dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret); let ndarray = NDArrayObject::from_np_eye( generator, ctx, dtype, nrows, ncols, offset, );