From d90604b7132b0277db94e8fff24db39ef3e64cb3 Mon Sep 17 00:00:00 2001 From: lyken Date: Tue, 16 Jul 2024 00:27:50 +0800 Subject: [PATCH] WIP --- nac3core/irrt/irrt/numpy/ndarray_basic.hpp | 48 ++++ .../irrt/irrt/numpy/ndarray_broadcast.hpp | 69 ++++-- .../irrt/irrt/numpy/ndarray_subscript.hpp | 10 +- nac3core/irrt/irrt_test.cpp | 2 + nac3core/irrt/test/test_ndarray_broadcast.hpp | 72 ++++++ nac3core/src/codegen/expr.rs | 6 +- nac3core/src/codegen/irrt/numpy/subscript.rs | 4 +- nac3core/src/codegen/irrt/util.rs | 2 +- nac3core/src/codegen/model/core.rs | 11 +- nac3core/src/codegen/model/gep.rs | 48 +++- nac3core/src/codegen/model/int.rs | 4 +- nac3core/src/codegen/model/pointer.rs | 6 +- nac3core/src/codegen/stmt.rs | 214 +++++++++++------- nac3core/src/toplevel/builtins.rs | 2 +- 14 files changed, 368 insertions(+), 130 deletions(-) create mode 100644 nac3core/irrt/test/test_ndarray_broadcast.hpp diff --git a/nac3core/irrt/irrt/numpy/ndarray_basic.hpp b/nac3core/irrt/irrt/numpy/ndarray_basic.hpp index 49b02eff..4f154fa6 100644 --- a/nac3core/irrt/irrt/numpy/ndarray_basic.hpp +++ b/nac3core/irrt/irrt/numpy/ndarray_basic.hpp @@ -126,6 +126,54 @@ namespace { namespace ndarray { namespace basic { *dst_length = (SliceIndex) ndarray->shape[0]; } + + // Copy data from one ndarray to another *OF THE EXACT SAME* ndims, shape, and itemsize. + template + void copy_data(const NDArray* src_ndarray, NDArray* dst_ndarray) { + __builtin_assume(src_ndarray->ndims == dst_ndarray->ndims); + __builtin_assume(src_ndarray->itemsize == dst_ndarray->itemsize); + + for (SizeT i = 0; i < src_ndarray->size; i++) { + auto src_element = ndarray::basic::get_nth_pelement(src_ndarray, i); + auto dst_element = ndarray::basic::get_nth_pelement(dst_ndarray, i); + ndarray::basic::set_pelement_value(dst_ndarray, dst_element, src_element); + } + } + + // `copy_data()` with assertions to check ndims, shape, and itemsize between the two ndarrays. + template + void copy_data_checked(ErrorContext* errctx, const NDArray* src_ndarray, NDArray* dst_ndarray) { + // NOTE: Out of all error types, runtime error seems appropriate + + // Check ndims + if (src_ndarray->ndims != dst_ndarray->ndims) { + errctx->set_error( + errctx->error_ids->runtime_error, + "IRRT copy_data_checked input arrays `ndims` are mismatched" + ); + return; // Terminate + } + + // Check shape + if (!arrays_match(src_ndarray->ndims, src_ndarray->shape, dst_ndarray->shape)) { + errctx->set_error( + errctx->error_ids->runtime_error, + "IRRT copy_data_checked input arrays `shape` are mismatched" + ); + return; // Terminate + } + + // Check itemsize + if (src_ndarray->itemsize != dst_ndarray->itemsize) { + errctx->set_error( + errctx->error_ids->runtime_error, + "IRRT copy_data_checked input arrays `itemsize` are mismatched" + ); + return; // Terminate + } + + copy_data(src_ndarray, dst_ndarray); + } } } } extern "C" { diff --git a/nac3core/irrt/irrt/numpy/ndarray_broadcast.hpp b/nac3core/irrt/irrt/numpy/ndarray_broadcast.hpp index 4b4d9d57..2d368811 100644 --- a/nac3core/irrt/irrt/numpy/ndarray_broadcast.hpp +++ b/nac3core/irrt/irrt/numpy/ndarray_broadcast.hpp @@ -3,11 +3,12 @@ namespace { namespace ndarray { namespace broadcast { namespace util { template - bool can_broadcast_shape_to( + void assert_broadcast_shape_to( + ErrorContext* errctx, const SizeT target_ndims, - const SizeT *target_shape, + const SizeT* target_shape, const SizeT src_ndims, - const SizeT *src_shape + const SizeT* src_shape ) { /* // See https://numpy.org/doc/stable/user/basics.broadcasting.html @@ -20,23 +21,33 @@ namespace { namespace ndarray { namespace broadcast { ``` Other interesting examples to consider: - - `can_broadcast_shape_to([3], [1, 1, 1, 1, 3]) == true` + - `can_broadcast_shape_to([3], [1, 1, 1, 1, 3]) ... ok` - `can_broadcast_shape_to([3], [3, 1]) == false` - - `can_broadcast_shape_to([256, 256, 3], [256, 1, 3]) == true` + - `can_broadcast_shape_to([256, 256, 3], [256, 1, 3]) ... ok` In cases when the shapes contain zero(es): - - `can_broadcast_shape_to([0], [1]) == true` + - `can_broadcast_shape_to([0], [1]) ... ok` - `can_broadcast_shape_to([0], [2]) == false` - - `can_broadcast_shape_to([0, 4, 0, 0], [1]) == true` - - `can_broadcast_shape_to([0, 4, 0, 0], [1, 1, 1, 1]) == true` - - `can_broadcast_shape_to([0, 4, 0, 0], [1, 4, 1, 1]) == true` + - `can_broadcast_shape_to([0, 4, 0, 0], [1]) ... ok` + - `can_broadcast_shape_to([0, 4, 0, 0], [1, 1, 1, 1]) ... ok` + - `can_broadcast_shape_to([0, 4, 0, 0], [1, 4, 1, 1]) ... ok` - `can_broadcast_shape_to([4, 3], [0, 3]) == false` - `can_broadcast_shape_to([4, 3], [0, 0]) == false` */ - // This is essentially doing the following in Python: - // `for target_dim, src_dim in itertools.zip_longest(target_shape[::-1], src_shape[::-1], fillvalue=1)` - for (SizeT i = 0; i < max(target_ndims, src_ndims); i++) { + // Target ndims must not be smaller than source ndims + // e.g., `np.broadcast_to(np.zeros((1, 1, 1, 1)), (1, ))` is prohibited by numpy + if (target_ndims < src_ndims) { + // Error copied from python by doing the `np.broadcast_to(np.zeros((1, 1, 1, 1)), (1, ))` + errctx->set_error( + errctx->error_ids->value_error, + "input operand has more dimensions than allowed by the axis remapping" + ); + return; // Terminate + } + + // Implements the rules in https://numpy.org/doc/stable/user/basics.broadcasting.html + for (SizeT i = 0; i < src_ndims; i++) { SizeT target_axis = target_ndims - i - 1; SizeT src_axis = src_ndims - i - 1; @@ -47,10 +58,18 @@ namespace { namespace ndarray { namespace broadcast { SizeT src_dim = src_dim_exists ? src_shape[src_axis] : 1; bool ok = src_dim == 1 || target_dim == src_dim; - if (!ok) return false; + if (!ok) { + // Error copied from python by doing `np.broadcast_to(np.zeros((3, 1)), (1, 1)), + // but this is the true numpy error: + // "ValueError: operands could not be broadcast together with remapped shapes [original->remapped]: (3,1) and requested shape (1,1)" + // TODO: we cannot show more than 3 parameters!! + errctx->set_error( + errctx->error_ids->value_error, + "operands could not be broadcast together with remapping shapes [original->remapped]" + ); + return; // Terminate + } } - - return true; } } @@ -79,18 +98,20 @@ namespace { namespace ndarray { namespace broadcast { // # This implementation will NOT support this assignment. // ``` template - void broadcast_to(NDArray* src_ndarray, NDArray* dst_ndarray) { + void broadcast_to(ErrorContext* errctx, NDArray* src_ndarray, NDArray* dst_ndarray) { dst_ndarray->data = src_ndarray->data; dst_ndarray->itemsize = src_ndarray->itemsize; - // irrt_assert( - // ndarray_util::can_broadcast_shape_to( - // dst_ndarray->ndims, - // dst_ndarray->shape, - // src_ndarray->ndims, - // src_ndarray->shape - // ) - // ); + ndarray::broadcast::util::assert_broadcast_shape_to( + errctx, + dst_ndarray->ndims, + dst_ndarray->shape, + src_ndarray->ndims, + src_ndarray->shape + ); + if (errctx->has_error()) { + return; // Propagate error + } SizeT stride_product = 1; for (SizeT i = 0; i < max(src_ndarray->ndims, dst_ndarray->ndims); i++) { diff --git a/nac3core/irrt/irrt/numpy/ndarray_subscript.hpp b/nac3core/irrt/irrt/numpy/ndarray_subscript.hpp index c4b13734..1926d1de 100644 --- a/nac3core/irrt/irrt/numpy/ndarray_subscript.hpp +++ b/nac3core/irrt/irrt/numpy/ndarray_subscript.hpp @@ -6,8 +6,6 @@ #include namespace { - typedef uint32_t NumNDSubscriptsType; - typedef uint8_t NDSubscriptType; const NDSubscriptType INPUT_SUBSCRIPT_TYPE_INDEX = 0; @@ -72,7 +70,7 @@ namespace { namespace ndarray { namespace subscript { // - `dst_ndarray->itemsize` does not have to be set, it will be set to `src_ndarray->itemsize` // - `dst_ndarray->shape` and `dst_ndarray.strides` can contain empty values template - void subscript(ErrorContext* errctx, NumNDSubscriptsType num_subscripts, NDSubscript* subscripts, NDArray* src_ndarray, NDArray* dst_ndarray) { + void subscript(ErrorContext* errctx, SliceIndex num_subscripts, NDSubscript* subscripts, NDArray* src_ndarray, NDArray* dst_ndarray) { // REFERENCE CODE (check out `_index_helper` in `__getitem__`): // https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L652 @@ -84,7 +82,7 @@ namespace { namespace ndarray { namespace subscript { SizeT src_axis = 0; SizeT dst_axis = 0; - for (SizeT i = 0; i < num_subscripts; i++) { + for (SliceIndex i = 0; i < num_subscripts; i++) { NDSubscript *ndsubscript = &subscripts[i]; if (ndsubscript->type == INPUT_SUBSCRIPT_TYPE_INDEX) { // Handle when the ndsubscript is just a single (possibly negative) integer @@ -161,11 +159,11 @@ extern "C" { ndarray::subscript::util::deduce_ndims_after_slicing(errctx, result, ndims, num_ndsubscripts, ndsubscripts); } - void __nac3_ndarray_subscript(ErrorContext* errctx, NumNDSubscriptsType num_subscripts, NDSubscript* subscripts, NDArray* src_ndarray, NDArray *dst_ndarray) { + void __nac3_ndarray_subscript(ErrorContext* errctx, SliceIndex num_subscripts, NDSubscript* subscripts, NDArray* src_ndarray, NDArray *dst_ndarray) { subscript(errctx, num_subscripts, subscripts, src_ndarray, dst_ndarray); } - void __nac3_ndarray_subscript64(ErrorContext* errctx, NumNDSubscriptsType num_subscripts, NDSubscript* subscripts, NDArray* src_ndarray, NDArray *dst_ndarray) { + void __nac3_ndarray_subscript64(ErrorContext* errctx, SliceIndex num_subscripts, NDSubscript* subscripts, NDArray* src_ndarray, NDArray *dst_ndarray) { subscript(errctx, num_subscripts, subscripts, src_ndarray, dst_ndarray); } } \ No newline at end of file diff --git a/nac3core/irrt/irrt_test.cpp b/nac3core/irrt/irrt_test.cpp index 407e5d4c..7fc49ba2 100644 --- a/nac3core/irrt/irrt_test.cpp +++ b/nac3core/irrt/irrt_test.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include int main() { @@ -19,5 +20,6 @@ int main() { test::slice::run(); test::ndarray_basic::run(); test::ndarray_subscript::run(); + test::ndarray_broadcast::run(); return 0; } \ No newline at end of file diff --git a/nac3core/irrt/test/test_ndarray_broadcast.hpp b/nac3core/irrt/test/test_ndarray_broadcast.hpp new file mode 100644 index 00000000..b5f18b12 --- /dev/null +++ b/nac3core/irrt/test/test_ndarray_broadcast.hpp @@ -0,0 +1,72 @@ +#pragma once + +#include +#include + +namespace test { namespace ndarray_broadcast { + void test_ndarray_broadcast_1() { + /* + ```python + array = np.array([[19.9, 29.9, 39.9, 49.9]], dtype=np.float64) + >>> [[19.9 29.9 39.9 49.9]] + + array = np.broadcast_to(array, (2, 3, 4)) + >>> [[[19.9 29.9 39.9 49.9] + >>> [19.9 29.9 39.9 49.9] + >>> [19.9 29.9 39.9 49.9]] + >>> [[19.9 29.9 39.9 49.9] + >>> [19.9 29.9 39.9 49.9] + >>> [19.9 29.9 39.9 49.9]]] + + assert array.strides == (0, 0, 8) + # and then pick some values in `array` and check them... + ``` + */ + BEGIN_TEST(); + + // Prepare src_ndarray + double src_data[4] = { 19.9, 29.9, 39.9, 49.9 }; + const int32_t src_ndims = 2; + int32_t src_shape[src_ndims] = {1, 4}; + int32_t src_strides[src_ndims] = {}; + NDArray src_ndarray = { + .data = (uint8_t*) src_data, + .itemsize = sizeof(double), + .ndims = src_ndims, + .shape = src_shape, + .strides = src_strides + }; + ndarray::basic::set_strides_by_shape(&src_ndarray); + + // Prepare dst_ndarray + const int32_t dst_ndims = 3; + int32_t dst_shape[dst_ndims] = {2, 3, 4}; + int32_t dst_strides[dst_ndims] = {}; + NDArray dst_ndarray = { + .ndims = dst_ndims, + .shape = dst_shape, + .strides = dst_strides + }; + + // Broadcast + ErrorContext errctx = create_testing_errctx(); + ndarray::broadcast::broadcast_to(&errctx, &src_ndarray, &dst_ndarray); + assert_errctx_no_error(&errctx); + + assert_arrays_match(dst_ndims, ((int32_t[]) { 0, 0, 8 }), dst_ndarray.strides); + + assert_values_match(19.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 0, 0})))); + assert_values_match(29.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 0, 1})))); + assert_values_match(39.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 0, 2})))); + assert_values_match(49.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 0, 3})))); + assert_values_match(19.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 1, 0})))); + assert_values_match(29.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 1, 1})))); + assert_values_match(39.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 1, 2})))); + assert_values_match(49.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 1, 3})))); + assert_values_match(49.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {1, 2, 3})))); + } + + void run() { + test_ndarray_broadcast_1(); + } +}} \ No newline at end of file diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 430bd859..db1e9b66 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -2189,7 +2189,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( Ok(match value_expr { None => None, Some(value_expr) => Some( - slice_index_model.check_llvm_value( + slice_index_model.review( generator .gen_expr(ctx, value_expr)? .unwrap() @@ -2209,7 +2209,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( // Anything else that is not a slice (might be illegal values), // For nac3core, this should be e.g., an int32 constant, an int32 variable, otherwise its an error - let index = slice_index_model.check_llvm_value( + let index = slice_index_model.review( generator .gen_expr(ctx, subscript_expr)? .unwrap() @@ -2931,7 +2931,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( let ndarray_ptr_model = PointerModel(StructModel(NpArray { sizet })); let v = v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?; - ndarray_ptr_model.check_llvm_value(v.as_any_value_enum()) + ndarray_ptr_model.review(v.as_any_value_enum()) } else { return Ok(None); }; diff --git a/nac3core/src/codegen/irrt/numpy/subscript.rs b/nac3core/src/codegen/irrt/numpy/subscript.rs index d59fa620..ce26b35b 100644 --- a/nac3core/src/codegen/irrt/numpy/subscript.rs +++ b/nac3core/src/codegen/irrt/numpy/subscript.rs @@ -156,7 +156,7 @@ pub fn call_nac3_ndarray_subscript_deduce_ndims_after_slicing<'ctx, G: CodeGener pub fn call_nac3_ndarray_subscript<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - num_subscripts: FixedInt<'ctx, Int32>, + num_subscripts: SliceIndex<'ctx>, subscripts: Pointer<'ctx, StructModel>, src_ndarray: Pointer<'ctx, StructModel>>, dst_ndarray: Pointer<'ctx, StructModel>>, @@ -171,7 +171,7 @@ pub fn call_nac3_ndarray_subscript<'ctx, G: CodeGenerator + ?Sized>( &get_sized_dependent_function_name(sizet, "__nac3_ndarray_subscript"), ) .arg("errctx", PointerModel(StructModel(ErrorContext)), errctx_ptr) - .arg("num_subscripts", FixedIntModel(Int32), num_subscripts) + .arg("num_subscripts", SliceIndexModel::default(), num_subscripts) .arg("subscripts", PointerModel(StructModel(NDSubscript)), subscripts) .arg("src_ndarray", PointerModel(StructModel(NpArray { sizet })), src_ndarray) .arg("dst_ndarray", PointerModel(StructModel(NpArray { sizet })), dst_ndarray) diff --git a/nac3core/src/codegen/irrt/util.rs b/nac3core/src/codegen/irrt/util.rs index 97466285..4b27ecb9 100644 --- a/nac3core/src/codegen/irrt/util.rs +++ b/nac3core/src/codegen/irrt/util.rs @@ -61,7 +61,7 @@ impl<'ctx, 'a> FunctionBuilder<'ctx, 'a> { }); let ret = self.ctx.builder.build_call(function, ¶m_vals, name).unwrap(); - return_model.check_llvm_value(ret.as_any_value_enum()) + return_model.review(ret.as_any_value_enum()) } // TODO: Code duplication, but otherwise returning> cannot resolve S if return_optic = None diff --git a/nac3core/src/codegen/model/core.rs b/nac3core/src/codegen/model/core.rs index 96b08ccd..13c4830b 100644 --- a/nac3core/src/codegen/model/core.rs +++ b/nac3core/src/codegen/model/core.rs @@ -22,11 +22,18 @@ pub trait ModelValue<'ctx>: Clone + Copy { fn get_llvm_value(&self) -> BasicValueEnum<'ctx>; } -pub trait Model<'ctx>: Clone + Copy { +// Should have been within [`Model`], +// but rust object safety requirements made it necessary to +// split this interface out +pub trait CanCheckLLVMType { + fn check_llvm_type<'ctx>(&self, ctx: &'ctx Context) -> Result<(), String>; +} + +pub trait Model<'ctx>: Clone + Copy + CanCheckLLVMType + Sized { type Value: ModelValue<'ctx>; fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx>; - fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value; + fn review(&self, value: AnyValueEnum<'ctx>) -> Self::Value; fn alloca(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> Pointer<'ctx, Self> { Pointer { diff --git a/nac3core/src/codegen/model/gep.rs b/nac3core/src/codegen/model/gep.rs index 44924e6f..14ac7cf0 100644 --- a/nac3core/src/codegen/model/gep.rs +++ b/nac3core/src/codegen/model/gep.rs @@ -7,7 +7,7 @@ use itertools::Itertools; use crate::codegen::CodeGenContext; -use super::{Model, ModelValue, Pointer}; +use super::{core::CanCheckLLVMType, Model, ModelValue, Pointer}; #[derive(Debug, Clone, Copy)] pub struct Field { @@ -17,14 +17,12 @@ pub struct Field { } // Like [`Field`] but element must be [`BasicTypeEnum<'ctx>`] -#[derive(Debug, Clone, Copy)] struct FieldLLVM<'ctx> { gep_index: u64, name: &'ctx str, - llvm_type: BasicTypeEnum<'ctx>, + llvm_type: Box, } -#[derive(Debug)] pub struct FieldBuilder<'ctx> { pub ctx: &'ctx Context, gep_index_counter: u64, @@ -57,6 +55,33 @@ impl<'ctx> FieldBuilder<'ctx> { } } +fn check_basic_types_match<'ctx, A, B>(expected: A, got: B) -> Result<(), String> +where + A: BasicType<'ctx>, + B: BasicType<'ctx>, +{ + let expected = expected.as_basic_type_enum(); + let got = got.as_basic_type_enum(); + + // Put those logic into here, + // otherwise there is always a fallback reporting on any kind of mismatch + match (expected, got) { + (BasicTypeEnum::IntType(expected), BasicTypeEnum::IntType(got)) => { + if expected.get_bit_width() != got.get_bit_width() { + return Err(format!( + "Expected IntType ({expected}-bit(s)), got IntType ({got}-bit(s))" + )); + } + } + (expected, got) => { + if expected != got { + return Err(format!("Expected {expected}, got {got}")); + } + } + } + Ok(()) +} + pub trait IsStruct<'ctx>: Clone + Copy { type Fields; @@ -75,7 +100,12 @@ pub trait IsStruct<'ctx>: Clone + Copy { let field_types = builder.fields.iter().map(|field_info| field_info.llvm_type).collect_vec(); - ctx.struct_type(&field_types, false) + ctx.struct_type(&field_types, false).as_basic_type_enum().into_pointer_type().get_el + } + + fn check_struct_type(&self) { + // Datatypes behind + // check_basic_types_match } } @@ -94,6 +124,12 @@ impl<'ctx, S: IsStruct<'ctx>> ModelValue<'ctx> for Struct<'ctx, S> { } } +impl<'ctx, S: IsStruct<'ctx>> CanCheckLLVMType<'ctx> for StructModel { + fn check_llvm_type<'ctx>(&self, ctx: &'ctx Context) -> Result<(), String> { + todo!() + } +} + impl<'ctx, S: IsStruct<'ctx>> Model<'ctx> for StructModel { type Value = Struct<'ctx, S>; // TODO: enrich it @@ -101,7 +137,7 @@ impl<'ctx, S: IsStruct<'ctx>> Model<'ctx> for StructModel { self.0.get_struct_type(ctx).as_basic_type_enum() } - fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value { + fn review(&self, value: AnyValueEnum<'ctx>) -> Self::Value { // TODO: check structure Struct { structure: self.0, value: value.into_struct_value() } } diff --git a/nac3core/src/codegen/model/int.rs b/nac3core/src/codegen/model/int.rs index 1d3118b7..611c7bee 100644 --- a/nac3core/src/codegen/model/int.rs +++ b/nac3core/src/codegen/model/int.rs @@ -27,7 +27,7 @@ impl<'ctx> Model<'ctx> for IntModel<'ctx> { self.0.as_basic_type_enum() } - fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value { + fn review(&self, value: AnyValueEnum<'ctx>) -> Self::Value { let int = value.into_int_value(); assert_eq!(int.get_type().get_bit_width(), self.0.get_bit_width()); Int(int) @@ -130,7 +130,7 @@ impl<'ctx, T: IsFixedInt> Model<'ctx> for FixedIntModel { T::get_int_type(ctx).as_basic_type_enum() } - fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value { + fn review(&self, value: AnyValueEnum<'ctx>) -> Self::Value { let value = value.into_int_value(); assert_eq!(value.get_type().get_bit_width(), T::get_bit_width()); FixedInt { int: self.0, value } diff --git a/nac3core/src/codegen/model/pointer.rs b/nac3core/src/codegen/model/pointer.rs index 086e7834..969aa124 100644 --- a/nac3core/src/codegen/model/pointer.rs +++ b/nac3core/src/codegen/model/pointer.rs @@ -31,7 +31,7 @@ impl<'ctx, E: Model<'ctx>> Pointer<'ctx, E> { pub fn load(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> E::Value { let val = ctx.builder.build_load(self.value, name).unwrap(); - self.element.check_llvm_value(val.as_any_value_enum()) + self.element.review(val.as_any_value_enum()) } pub fn to_opaque(self) -> OpaquePointer<'ctx> { @@ -66,7 +66,7 @@ impl<'ctx, E: Model<'ctx>> Model<'ctx> for PointerModel { self.0.get_llvm_type(ctx).ptr_type(AddressSpace::default()).as_basic_type_enum() } - fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value { + fn review(&self, value: AnyValueEnum<'ctx>) -> Self::Value { // TODO: Check get_element_type()? for LLVM 14 at least... Pointer { element: self.0, value: value.into_pointer_value() } } @@ -92,7 +92,7 @@ impl<'ctx> Model<'ctx> for OpaquePointerModel { ctx.i8_type().ptr_type(AddressSpace::default()).as_basic_type_enum() } - fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value { + fn review(&self, value: AnyValueEnum<'ctx>) -> Self::Value { let ptr = value.into_pointer_value(); // TODO: remove this check once LLVM pointers do not have `get_element_type()` assert_eq!(ptr.get_type().get_element_type().into_int_type().get_bit_width(), 8); diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index cf16d3e5..ac965109 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -189,10 +189,6 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>( v.data().ptr_offset(ctx, generator, &index, name) } - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - todo!() - } - _ => unreachable!(), } } @@ -207,89 +203,147 @@ pub fn gen_assign<'ctx, G: CodeGenerator>( target: &Expr>, value: ValueEnum<'ctx>, ) -> Result<(), String> { + /* + To handle assignment statements `target = value`, with + special care taken for targets `gen_store_target` cannot handle, these are: + - Case 1. target is a Tuple + - e.g., `(x, y, z, w) = value` + - Case 2. *Sliced* list assignment `list.__setitem__` + - e.g., `my_list[1:3] = [100, 101]`, BUT NOT `my_list[0] = 99` (gen_store_target knows how to handle these), + - Case 3. Indexed ndarray assignment `ndarray.__setitem__` + - e.g., `my_ndarray[::-1, :] = 3`, `my_ndarray[:, 3::-1] = their_ndarray[10::2]` + - NOTE: Technically speaking, if `target` is sliced in such as way that it is referencing a + single element/scalar, we *could* implement gen_store_target for this special case; + but it is much, *much* simpler to generalize all indexed ndarray assignment without + special handling on that edgecase. + - Otherwise, use `gen_store_target` + */ + let llvm_usize = generator.get_size_type(ctx.ctx); - match &target.node { - ExprKind::Tuple { elts, .. } => { - let BasicValueEnum::StructValue(v) = - value.to_basic_value_enum(ctx, generator, target.custom.unwrap())? - else { - unreachable!() - }; + if let ExprKind::Tuple { elts, .. } = &target.node { + // Handle Case 1. target is a Tuple + let BasicValueEnum::StructValue(v) = + value.to_basic_value_enum(ctx, generator, target.custom.unwrap())? + else { + unreachable!() + }; - for (i, elt) in elts.iter().enumerate() { - let v = ctx - .builder - .build_extract_value(v, u32::try_from(i).unwrap(), "struct_elem") - .unwrap(); - generator.gen_assign(ctx, elt, v.into())?; + for (i, elt) in elts.iter().enumerate() { + let v = ctx + .builder + .build_extract_value(v, u32::try_from(i).unwrap(), "struct_elem") + .unwrap(); + generator.gen_assign(ctx, elt, v.into())?; + } + + return Ok(()); // Terminate + } + + // Else, try checking if it's Case 2 or 3, and they *ONLY* + // happen if `target.node` is a `ExprKind::Subscript`, so do a special check + if let ExprKind::Subscript { value: target_without_slice, slice, .. } = &target.node { + // Get the type of target + let target_ty = target.custom.unwrap(); + let target_ty_enum = &*ctx.unifier.get_ty(target_ty); + + // Pattern match on this pair. + // This is done like this because of Case 2 - slice.node has to be in a specific pattern + match (target_ty_enum, &slice.node) { + (TypeEnum::TObj { obj_id, .. }, ExprKind::Slice { lower, upper, step }) + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + // Case 2. *Sliced* list assignment + + let ls = generator + .gen_expr(ctx, target_without_slice)? + .unwrap() + .to_basic_value_enum(ctx, generator, target_without_slice.custom.unwrap())? + .into_pointer_value(); + let ls = ListValue::from_ptr_val(ls, llvm_usize, None); + let Some((start, end, step)) = handle_slice_indices( + lower, + upper, + step, + ctx, + generator, + ls.load_size(ctx, None), + )? + else { + return Ok(()); + }; + let value = value + .to_basic_value_enum(ctx, generator, target.custom.unwrap())? + .into_pointer_value(); + let value = ListValue::from_ptr_val(value, llvm_usize, None); + let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) { + TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => { + *params.iter().next().unwrap().1 + } + TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { + unpack_ndarray_var_tys(&mut ctx.unifier, target.custom.unwrap()).0 + } + _ => unreachable!(), + }; + + let ty = ctx.get_llvm_type(generator, ty); + let Some(src_ind) = handle_slice_indices( + &None, + &None, + &None, + ctx, + generator, + value.load_size(ctx, None), + )? + else { + return Ok(()); + }; + list_slice_assignment(generator, ctx, ty, ls, (start, end, step), value, src_ind); + + return Ok(()); // Terminate + } + (TypeEnum::TObj { obj_id, .. }, _) + if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => + { + // Case 3. Indexed ndarray assignment + + let target = generator.gen_expr(ctx, target)?.unwrap().to_basic_value_enum( + ctx, + generator, + target.custom.unwrap(), + ); + + // let value = value.to_basic_value_enum(ctx, generator, value); + + todo!(); + + return Ok(()); // Terminate + } + _ => { + // Fallthrough } } - ExprKind::Subscript { value: ls, slice, .. } - if matches!(&slice.node, ExprKind::Slice { .. }) => - { - let ExprKind::Slice { lower, upper, step } = &slice.node else { unreachable!() }; + } - let ls = generator - .gen_expr(ctx, ls)? - .unwrap() - .to_basic_value_enum(ctx, generator, ls.custom.unwrap())? - .into_pointer_value(); - let ls = ListValue::from_ptr_val(ls, llvm_usize, None); - let Some((start, end, step)) = - handle_slice_indices(lower, upper, step, ctx, generator, ls.load_size(ctx, None))? - else { - return Ok(()); - }; - let value = value - .to_basic_value_enum(ctx, generator, target.custom.unwrap())? - .into_pointer_value(); - let value = ListValue::from_ptr_val(value, llvm_usize, None); - let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) { - TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => { - *params.iter().next().unwrap().1 - } - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - unpack_ndarray_var_tys(&mut ctx.unifier, target.custom.unwrap()).0 - } - _ => unreachable!(), - }; - - let ty = ctx.get_llvm_type(generator, ty); - let Some(src_ind) = handle_slice_indices( - &None, - &None, - &None, - ctx, - generator, - value.load_size(ctx, None), - )? - else { - return Ok(()); - }; - list_slice_assignment(generator, ctx, ty, ls, (start, end, step), value, src_ind); - } - _ => { - let name = if let ExprKind::Name { id, .. } = &target.node { - format!("{id}.addr") - } else { - String::from("target.addr") - }; - let Some(ptr) = generator.gen_store_target(ctx, target, Some(name.as_str()))? else { - return Ok(()); - }; - - if let ExprKind::Name { id, .. } = &target.node { - let (_, static_value, counter) = ctx.var_assignment.get_mut(id).unwrap(); - *counter += 1; - if let ValueEnum::Static(s) = &value { - *static_value = Some(s.clone()); - } - } - let val = value.to_basic_value_enum(ctx, generator, target.custom.unwrap())?; - ctx.builder.build_store(ptr, val).unwrap(); - } + // None of the cases match. We should actually use `gen_store_target`. + let name = if let ExprKind::Name { id, .. } = &target.node { + format!("{id}.addr") + } else { + String::from("target.addr") }; + let Some(ptr) = generator.gen_store_target(ctx, target, Some(name.as_str()))? else { + return Ok(()); + }; + + if let ExprKind::Name { id, .. } = &target.node { + let (_, static_value, counter) = ctx.var_assignment.get_mut(id).unwrap(); + *counter += 1; + if let ValueEnum::Static(s) = &value { + *static_value = Some(s.clone()); + } + } + let val = value.to_basic_value_enum(ctx, generator, target.custom.unwrap())?; + ctx.builder.build_store(ptr, val).unwrap(); Ok(()) } diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index cf089368..ef4038f7 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1470,7 +1470,7 @@ impl<'a> BuiltinBuilder<'a> { let ndarray_ptr_model = PointerModel(StructModel(NpArray { sizet })); let ndarray_ptr = - ndarray_ptr_model.check_llvm_value(arg.as_any_value_enum()); + ndarray_ptr_model.review(arg.as_any_value_enum()); // Calculate len // NOTE: Unsized object is asserted in IRRT