forked from M-Labs/nac3
1
0
Fork 0
This commit is contained in:
lyken 2024-07-16 00:27:50 +08:00
parent 0946bd86ea
commit d90604b713
14 changed files with 368 additions and 130 deletions

View File

@ -126,6 +126,54 @@ namespace { namespace ndarray { namespace basic {
*dst_length = (SliceIndex) ndarray->shape[0];
}
// Copy data from one ndarray to another *OF THE EXACT SAME* ndims, shape, and itemsize.
template <typename SizeT>
void copy_data(const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
__builtin_assume(src_ndarray->ndims == dst_ndarray->ndims);
__builtin_assume(src_ndarray->itemsize == dst_ndarray->itemsize);
for (SizeT i = 0; i < src_ndarray->size; i++) {
auto src_element = ndarray::basic::get_nth_pelement(src_ndarray, i);
auto dst_element = ndarray::basic::get_nth_pelement(dst_ndarray, i);
ndarray::basic::set_pelement_value(dst_ndarray, dst_element, src_element);
}
}
// `copy_data()` with assertions to check ndims, shape, and itemsize between the two ndarrays.
template <typename SizeT>
void copy_data_checked(ErrorContext* errctx, const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
// NOTE: Out of all error types, runtime error seems appropriate
// Check ndims
if (src_ndarray->ndims != dst_ndarray->ndims) {
errctx->set_error(
errctx->error_ids->runtime_error,
"IRRT copy_data_checked input arrays `ndims` are mismatched"
);
return; // Terminate
}
// Check shape
if (!arrays_match(src_ndarray->ndims, src_ndarray->shape, dst_ndarray->shape)) {
errctx->set_error(
errctx->error_ids->runtime_error,
"IRRT copy_data_checked input arrays `shape` are mismatched"
);
return; // Terminate
}
// Check itemsize
if (src_ndarray->itemsize != dst_ndarray->itemsize) {
errctx->set_error(
errctx->error_ids->runtime_error,
"IRRT copy_data_checked input arrays `itemsize` are mismatched"
);
return; // Terminate
}
copy_data(src_ndarray, dst_ndarray);
}
} } }
extern "C" {

View File

@ -3,7 +3,8 @@
namespace { namespace ndarray { namespace broadcast {
namespace util {
template <typename SizeT>
bool can_broadcast_shape_to(
void assert_broadcast_shape_to(
ErrorContext* errctx,
const SizeT target_ndims,
const SizeT* target_shape,
const SizeT src_ndims,
@ -20,23 +21,33 @@ namespace { namespace ndarray { namespace broadcast {
```
Other interesting examples to consider:
- `can_broadcast_shape_to([3], [1, 1, 1, 1, 3]) == true`
- `can_broadcast_shape_to([3], [1, 1, 1, 1, 3]) ... ok`
- `can_broadcast_shape_to([3], [3, 1]) == false`
- `can_broadcast_shape_to([256, 256, 3], [256, 1, 3]) == true`
- `can_broadcast_shape_to([256, 256, 3], [256, 1, 3]) ... ok`
In cases when the shapes contain zero(es):
- `can_broadcast_shape_to([0], [1]) == true`
- `can_broadcast_shape_to([0], [1]) ... ok`
- `can_broadcast_shape_to([0], [2]) == false`
- `can_broadcast_shape_to([0, 4, 0, 0], [1]) == true`
- `can_broadcast_shape_to([0, 4, 0, 0], [1, 1, 1, 1]) == true`
- `can_broadcast_shape_to([0, 4, 0, 0], [1, 4, 1, 1]) == true`
- `can_broadcast_shape_to([0, 4, 0, 0], [1]) ... ok`
- `can_broadcast_shape_to([0, 4, 0, 0], [1, 1, 1, 1]) ... ok`
- `can_broadcast_shape_to([0, 4, 0, 0], [1, 4, 1, 1]) ... ok`
- `can_broadcast_shape_to([4, 3], [0, 3]) == false`
- `can_broadcast_shape_to([4, 3], [0, 0]) == false`
*/
// This is essentially doing the following in Python:
// `for target_dim, src_dim in itertools.zip_longest(target_shape[::-1], src_shape[::-1], fillvalue=1)`
for (SizeT i = 0; i < max(target_ndims, src_ndims); i++) {
// Target ndims must not be smaller than source ndims
// e.g., `np.broadcast_to(np.zeros((1, 1, 1, 1)), (1, ))` is prohibited by numpy
if (target_ndims < src_ndims) {
// Error copied from python by doing the `np.broadcast_to(np.zeros((1, 1, 1, 1)), (1, ))`
errctx->set_error(
errctx->error_ids->value_error,
"input operand has more dimensions than allowed by the axis remapping"
);
return; // Terminate
}
// Implements the rules in https://numpy.org/doc/stable/user/basics.broadcasting.html
for (SizeT i = 0; i < src_ndims; i++) {
SizeT target_axis = target_ndims - i - 1;
SizeT src_axis = src_ndims - i - 1;
@ -47,10 +58,18 @@ namespace { namespace ndarray { namespace broadcast {
SizeT src_dim = src_dim_exists ? src_shape[src_axis] : 1;
bool ok = src_dim == 1 || target_dim == src_dim;
if (!ok) return false;
if (!ok) {
// Error copied from python by doing `np.broadcast_to(np.zeros((3, 1)), (1, 1)),
// but this is the true numpy error:
// "ValueError: operands could not be broadcast together with remapped shapes [original->remapped]: (3,1) and requested shape (1,1)"
// TODO: we cannot show more than 3 parameters!!
errctx->set_error(
errctx->error_ids->value_error,
"operands could not be broadcast together with remapping shapes [original->remapped]"
);
return; // Terminate
}
}
return true;
}
}
@ -79,18 +98,20 @@ namespace { namespace ndarray { namespace broadcast {
// # This implementation will NOT support this assignment.
// ```
template <typename SizeT>
void broadcast_to(NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
void broadcast_to(ErrorContext* errctx, NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
dst_ndarray->data = src_ndarray->data;
dst_ndarray->itemsize = src_ndarray->itemsize;
// irrt_assert(
// ndarray_util::can_broadcast_shape_to(
// dst_ndarray->ndims,
// dst_ndarray->shape,
// src_ndarray->ndims,
// src_ndarray->shape
// )
// );
ndarray::broadcast::util::assert_broadcast_shape_to(
errctx,
dst_ndarray->ndims,
dst_ndarray->shape,
src_ndarray->ndims,
src_ndarray->shape
);
if (errctx->has_error()) {
return; // Propagate error
}
SizeT stride_product = 1;
for (SizeT i = 0; i < max(src_ndarray->ndims, dst_ndarray->ndims); i++) {

View File

@ -6,8 +6,6 @@
#include <irrt/error_context.hpp>
namespace {
typedef uint32_t NumNDSubscriptsType;
typedef uint8_t NDSubscriptType;
const NDSubscriptType INPUT_SUBSCRIPT_TYPE_INDEX = 0;
@ -72,7 +70,7 @@ namespace { namespace ndarray { namespace subscript {
// - `dst_ndarray->itemsize` does not have to be set, it will be set to `src_ndarray->itemsize`
// - `dst_ndarray->shape` and `dst_ndarray.strides` can contain empty values
template <typename SizeT>
void subscript(ErrorContext* errctx, NumNDSubscriptsType num_subscripts, NDSubscript* subscripts, NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
void subscript(ErrorContext* errctx, SliceIndex num_subscripts, NDSubscript* subscripts, NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
// REFERENCE CODE (check out `_index_helper` in `__getitem__`):
// https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L652
@ -84,7 +82,7 @@ namespace { namespace ndarray { namespace subscript {
SizeT src_axis = 0;
SizeT dst_axis = 0;
for (SizeT i = 0; i < num_subscripts; i++) {
for (SliceIndex i = 0; i < num_subscripts; i++) {
NDSubscript *ndsubscript = &subscripts[i];
if (ndsubscript->type == INPUT_SUBSCRIPT_TYPE_INDEX) {
// Handle when the ndsubscript is just a single (possibly negative) integer
@ -161,11 +159,11 @@ extern "C" {
ndarray::subscript::util::deduce_ndims_after_slicing(errctx, result, ndims, num_ndsubscripts, ndsubscripts);
}
void __nac3_ndarray_subscript(ErrorContext* errctx, NumNDSubscriptsType num_subscripts, NDSubscript* subscripts, NDArray<int32_t>* src_ndarray, NDArray<int32_t> *dst_ndarray) {
void __nac3_ndarray_subscript(ErrorContext* errctx, SliceIndex num_subscripts, NDSubscript* subscripts, NDArray<int32_t>* src_ndarray, NDArray<int32_t> *dst_ndarray) {
subscript(errctx, num_subscripts, subscripts, src_ndarray, dst_ndarray);
}
void __nac3_ndarray_subscript64(ErrorContext* errctx, NumNDSubscriptsType num_subscripts, NDSubscript* subscripts, NDArray<int64_t>* src_ndarray, NDArray<int64_t> *dst_ndarray) {
void __nac3_ndarray_subscript64(ErrorContext* errctx, SliceIndex num_subscripts, NDSubscript* subscripts, NDArray<int64_t>* src_ndarray, NDArray<int64_t> *dst_ndarray) {
subscript(errctx, num_subscripts, subscripts, src_ndarray, dst_ndarray);
}
}

View File

@ -11,6 +11,7 @@
#include <test/test_core.hpp>
#include <test/test_ndarray_basic.hpp>
#include <test/test_ndarray_subscript.hpp>
#include <test/test_ndarray_broadcast.hpp>
#include <test/test_slice.hpp>
int main() {
@ -19,5 +20,6 @@ int main() {
test::slice::run();
test::ndarray_basic::run();
test::ndarray_subscript::run();
test::ndarray_broadcast::run();
return 0;
}

View File

@ -0,0 +1,72 @@
#pragma once
#include <test/core.hpp>
#include <irrt_everything.hpp>
namespace test { namespace ndarray_broadcast {
void test_ndarray_broadcast_1() {
/*
```python
array = np.array([[19.9, 29.9, 39.9, 49.9]], dtype=np.float64)
>>> [[19.9 29.9 39.9 49.9]]
array = np.broadcast_to(array, (2, 3, 4))
>>> [[[19.9 29.9 39.9 49.9]
>>> [19.9 29.9 39.9 49.9]
>>> [19.9 29.9 39.9 49.9]]
>>> [[19.9 29.9 39.9 49.9]
>>> [19.9 29.9 39.9 49.9]
>>> [19.9 29.9 39.9 49.9]]]
assert array.strides == (0, 0, 8)
# and then pick some values in `array` and check them...
```
*/
BEGIN_TEST();
// Prepare src_ndarray
double src_data[4] = { 19.9, 29.9, 39.9, 49.9 };
const int32_t src_ndims = 2;
int32_t src_shape[src_ndims] = {1, 4};
int32_t src_strides[src_ndims] = {};
NDArray<int32_t> src_ndarray = {
.data = (uint8_t*) src_data,
.itemsize = sizeof(double),
.ndims = src_ndims,
.shape = src_shape,
.strides = src_strides
};
ndarray::basic::set_strides_by_shape(&src_ndarray);
// Prepare dst_ndarray
const int32_t dst_ndims = 3;
int32_t dst_shape[dst_ndims] = {2, 3, 4};
int32_t dst_strides[dst_ndims] = {};
NDArray<int32_t> dst_ndarray = {
.ndims = dst_ndims,
.shape = dst_shape,
.strides = dst_strides
};
// Broadcast
ErrorContext errctx = create_testing_errctx();
ndarray::broadcast::broadcast_to(&errctx, &src_ndarray, &dst_ndarray);
assert_errctx_no_error(&errctx);
assert_arrays_match(dst_ndims, ((int32_t[]) { 0, 0, 8 }), dst_ndarray.strides);
assert_values_match(19.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 0, 0}))));
assert_values_match(29.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 0, 1}))));
assert_values_match(39.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 0, 2}))));
assert_values_match(49.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 0, 3}))));
assert_values_match(19.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 1, 0}))));
assert_values_match(29.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 1, 1}))));
assert_values_match(39.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 1, 2}))));
assert_values_match(49.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 1, 3}))));
assert_values_match(49.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {1, 2, 3}))));
}
void run() {
test_ndarray_broadcast_1();
}
}}

View File

@ -2189,7 +2189,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
Ok(match value_expr {
None => None,
Some(value_expr) => Some(
slice_index_model.check_llvm_value(
slice_index_model.review(
generator
.gen_expr(ctx, value_expr)?
.unwrap()
@ -2209,7 +2209,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
// Anything else that is not a slice (might be illegal values),
// For nac3core, this should be e.g., an int32 constant, an int32 variable, otherwise its an error
let index = slice_index_model.check_llvm_value(
let index = slice_index_model.review(
generator
.gen_expr(ctx, subscript_expr)?
.unwrap()
@ -2931,7 +2931,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
let ndarray_ptr_model = PointerModel(StructModel(NpArray { sizet }));
let v = v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?;
ndarray_ptr_model.check_llvm_value(v.as_any_value_enum())
ndarray_ptr_model.review(v.as_any_value_enum())
} else {
return Ok(None);
};

View File

@ -156,7 +156,7 @@ pub fn call_nac3_ndarray_subscript_deduce_ndims_after_slicing<'ctx, G: CodeGener
pub fn call_nac3_ndarray_subscript<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
num_subscripts: FixedInt<'ctx, Int32>,
num_subscripts: SliceIndex<'ctx>,
subscripts: Pointer<'ctx, StructModel<NDSubscript>>,
src_ndarray: Pointer<'ctx, StructModel<NpArray<'ctx>>>,
dst_ndarray: Pointer<'ctx, StructModel<NpArray<'ctx>>>,
@ -171,7 +171,7 @@ pub fn call_nac3_ndarray_subscript<'ctx, G: CodeGenerator + ?Sized>(
&get_sized_dependent_function_name(sizet, "__nac3_ndarray_subscript"),
)
.arg("errctx", PointerModel(StructModel(ErrorContext)), errctx_ptr)
.arg("num_subscripts", FixedIntModel(Int32), num_subscripts)
.arg("num_subscripts", SliceIndexModel::default(), num_subscripts)
.arg("subscripts", PointerModel(StructModel(NDSubscript)), subscripts)
.arg("src_ndarray", PointerModel(StructModel(NpArray { sizet })), src_ndarray)
.arg("dst_ndarray", PointerModel(StructModel(NpArray { sizet })), dst_ndarray)

View File

@ -61,7 +61,7 @@ impl<'ctx, 'a> FunctionBuilder<'ctx, 'a> {
});
let ret = self.ctx.builder.build_call(function, &param_vals, name).unwrap();
return_model.check_llvm_value(ret.as_any_value_enum())
return_model.review(ret.as_any_value_enum())
}
// TODO: Code duplication, but otherwise returning<S: Optic<'ctx>> cannot resolve S if return_optic = None

View File

@ -22,11 +22,18 @@ pub trait ModelValue<'ctx>: Clone + Copy {
fn get_llvm_value(&self) -> BasicValueEnum<'ctx>;
}
pub trait Model<'ctx>: Clone + Copy {
// Should have been within [`Model<ctx>`],
// but rust object safety requirements made it necessary to
// split this interface out
pub trait CanCheckLLVMType {
fn check_llvm_type<'ctx>(&self, ctx: &'ctx Context) -> Result<(), String>;
}
pub trait Model<'ctx>: Clone + Copy + CanCheckLLVMType + Sized {
type Value: ModelValue<'ctx>;
fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx>;
fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value;
fn review(&self, value: AnyValueEnum<'ctx>) -> Self::Value;
fn alloca(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> Pointer<'ctx, Self> {
Pointer {

View File

@ -7,7 +7,7 @@ use itertools::Itertools;
use crate::codegen::CodeGenContext;
use super::{Model, ModelValue, Pointer};
use super::{core::CanCheckLLVMType, Model, ModelValue, Pointer};
#[derive(Debug, Clone, Copy)]
pub struct Field<E> {
@ -17,14 +17,12 @@ pub struct Field<E> {
}
// Like [`Field<E>`] but element must be [`BasicTypeEnum<'ctx>`]
#[derive(Debug, Clone, Copy)]
struct FieldLLVM<'ctx> {
gep_index: u64,
name: &'ctx str,
llvm_type: BasicTypeEnum<'ctx>,
llvm_type: Box<dyn CanCheckLLVMType>,
}
#[derive(Debug)]
pub struct FieldBuilder<'ctx> {
pub ctx: &'ctx Context,
gep_index_counter: u64,
@ -57,6 +55,33 @@ impl<'ctx> FieldBuilder<'ctx> {
}
}
fn check_basic_types_match<'ctx, A, B>(expected: A, got: B) -> Result<(), String>
where
A: BasicType<'ctx>,
B: BasicType<'ctx>,
{
let expected = expected.as_basic_type_enum();
let got = got.as_basic_type_enum();
// Put those logic into here,
// otherwise there is always a fallback reporting on any kind of mismatch
match (expected, got) {
(BasicTypeEnum::IntType(expected), BasicTypeEnum::IntType(got)) => {
if expected.get_bit_width() != got.get_bit_width() {
return Err(format!(
"Expected IntType ({expected}-bit(s)), got IntType ({got}-bit(s))"
));
}
}
(expected, got) => {
if expected != got {
return Err(format!("Expected {expected}, got {got}"));
}
}
}
Ok(())
}
pub trait IsStruct<'ctx>: Clone + Copy {
type Fields;
@ -75,7 +100,12 @@ pub trait IsStruct<'ctx>: Clone + Copy {
let field_types =
builder.fields.iter().map(|field_info| field_info.llvm_type).collect_vec();
ctx.struct_type(&field_types, false)
ctx.struct_type(&field_types, false).as_basic_type_enum().into_pointer_type().get_el
}
fn check_struct_type(&self) {
// Datatypes behind
// check_basic_types_match
}
}
@ -94,6 +124,12 @@ impl<'ctx, S: IsStruct<'ctx>> ModelValue<'ctx> for Struct<'ctx, S> {
}
}
impl<'ctx, S: IsStruct<'ctx>> CanCheckLLVMType<'ctx> for StructModel<S> {
fn check_llvm_type<'ctx>(&self, ctx: &'ctx Context) -> Result<(), String> {
todo!()
}
}
impl<'ctx, S: IsStruct<'ctx>> Model<'ctx> for StructModel<S> {
type Value = Struct<'ctx, S>; // TODO: enrich it
@ -101,7 +137,7 @@ impl<'ctx, S: IsStruct<'ctx>> Model<'ctx> for StructModel<S> {
self.0.get_struct_type(ctx).as_basic_type_enum()
}
fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value {
fn review(&self, value: AnyValueEnum<'ctx>) -> Self::Value {
// TODO: check structure
Struct { structure: self.0, value: value.into_struct_value() }
}

View File

@ -27,7 +27,7 @@ impl<'ctx> Model<'ctx> for IntModel<'ctx> {
self.0.as_basic_type_enum()
}
fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value {
fn review(&self, value: AnyValueEnum<'ctx>) -> Self::Value {
let int = value.into_int_value();
assert_eq!(int.get_type().get_bit_width(), self.0.get_bit_width());
Int(int)
@ -130,7 +130,7 @@ impl<'ctx, T: IsFixedInt> Model<'ctx> for FixedIntModel<T> {
T::get_int_type(ctx).as_basic_type_enum()
}
fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value {
fn review(&self, value: AnyValueEnum<'ctx>) -> Self::Value {
let value = value.into_int_value();
assert_eq!(value.get_type().get_bit_width(), T::get_bit_width());
FixedInt { int: self.0, value }

View File

@ -31,7 +31,7 @@ impl<'ctx, E: Model<'ctx>> Pointer<'ctx, E> {
pub fn load(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> E::Value {
let val = ctx.builder.build_load(self.value, name).unwrap();
self.element.check_llvm_value(val.as_any_value_enum())
self.element.review(val.as_any_value_enum())
}
pub fn to_opaque(self) -> OpaquePointer<'ctx> {
@ -66,7 +66,7 @@ impl<'ctx, E: Model<'ctx>> Model<'ctx> for PointerModel<E> {
self.0.get_llvm_type(ctx).ptr_type(AddressSpace::default()).as_basic_type_enum()
}
fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value {
fn review(&self, value: AnyValueEnum<'ctx>) -> Self::Value {
// TODO: Check get_element_type()? for LLVM 14 at least...
Pointer { element: self.0, value: value.into_pointer_value() }
}
@ -92,7 +92,7 @@ impl<'ctx> Model<'ctx> for OpaquePointerModel {
ctx.i8_type().ptr_type(AddressSpace::default()).as_basic_type_enum()
}
fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value {
fn review(&self, value: AnyValueEnum<'ctx>) -> Self::Value {
let ptr = value.into_pointer_value();
// TODO: remove this check once LLVM pointers do not have `get_element_type()`
assert_eq!(ptr.get_type().get_element_type().into_int_type().get_bit_width(), 8);

View File

@ -189,10 +189,6 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
v.data().ptr_offset(ctx, generator, &index, name)
}
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
todo!()
}
_ => unreachable!(),
}
}
@ -207,10 +203,26 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
target: &Expr<Option<Type>>,
value: ValueEnum<'ctx>,
) -> Result<(), String> {
/*
To handle assignment statements `target = value`, with
special care taken for targets `gen_store_target` cannot handle, these are:
- Case 1. target is a Tuple
- e.g., `(x, y, z, w) = value`
- Case 2. *Sliced* list assignment `list.__setitem__`
- e.g., `my_list[1:3] = [100, 101]`, BUT NOT `my_list[0] = 99` (gen_store_target knows how to handle these),
- Case 3. Indexed ndarray assignment `ndarray.__setitem__`
- e.g., `my_ndarray[::-1, :] = 3`, `my_ndarray[:, 3::-1] = their_ndarray[10::2]`
- NOTE: Technically speaking, if `target` is sliced in such as way that it is referencing a
single element/scalar, we *could* implement gen_store_target for this special case;
but it is much, *much* simpler to generalize all indexed ndarray assignment without
special handling on that edgecase.
- Otherwise, use `gen_store_target`
*/
let llvm_usize = generator.get_size_type(ctx.ctx);
match &target.node {
ExprKind::Tuple { elts, .. } => {
if let ExprKind::Tuple { elts, .. } = &target.node {
// Handle Case 1. target is a Tuple
let BasicValueEnum::StructValue(v) =
value.to_basic_value_enum(ctx, generator, target.custom.unwrap())?
else {
@ -224,20 +236,39 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
.unwrap();
generator.gen_assign(ctx, elt, v.into())?;
}
return Ok(()); // Terminate
}
ExprKind::Subscript { value: ls, slice, .. }
if matches!(&slice.node, ExprKind::Slice { .. }) =>
// Else, try checking if it's Case 2 or 3, and they *ONLY*
// happen if `target.node` is a `ExprKind::Subscript`, so do a special check
if let ExprKind::Subscript { value: target_without_slice, slice, .. } = &target.node {
// Get the type of target
let target_ty = target.custom.unwrap();
let target_ty_enum = &*ctx.unifier.get_ty(target_ty);
// Pattern match on this pair.
// This is done like this because of Case 2 - slice.node has to be in a specific pattern
match (target_ty_enum, &slice.node) {
(TypeEnum::TObj { obj_id, .. }, ExprKind::Slice { lower, upper, step })
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
{
let ExprKind::Slice { lower, upper, step } = &slice.node else { unreachable!() };
// Case 2. *Sliced* list assignment
let ls = generator
.gen_expr(ctx, ls)?
.gen_expr(ctx, target_without_slice)?
.unwrap()
.to_basic_value_enum(ctx, generator, ls.custom.unwrap())?
.to_basic_value_enum(ctx, generator, target_without_slice.custom.unwrap())?
.into_pointer_value();
let ls = ListValue::from_ptr_val(ls, llvm_usize, None);
let Some((start, end, step)) =
handle_slice_indices(lower, upper, step, ctx, generator, ls.load_size(ctx, None))?
let Some((start, end, step)) = handle_slice_indices(
lower,
upper,
step,
ctx,
generator,
ls.load_size(ctx, None),
)?
else {
return Ok(());
};
@ -268,8 +299,33 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
return Ok(());
};
list_slice_assignment(generator, ctx, ty, ls, (start, end, step), value, src_ind);
return Ok(()); // Terminate
}
(TypeEnum::TObj { obj_id, .. }, _)
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
{
// Case 3. Indexed ndarray assignment
let target = generator.gen_expr(ctx, target)?.unwrap().to_basic_value_enum(
ctx,
generator,
target.custom.unwrap(),
);
// let value = value.to_basic_value_enum(ctx, generator, value);
todo!();
return Ok(()); // Terminate
}
_ => {
// Fallthrough
}
}
}
// None of the cases match. We should actually use `gen_store_target`.
let name = if let ExprKind::Name { id, .. } = &target.node {
format!("{id}.addr")
} else {
@ -288,8 +344,6 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
}
let val = value.to_basic_value_enum(ctx, generator, target.custom.unwrap())?;
ctx.builder.build_store(ptr, val).unwrap();
}
};
Ok(())
}

View File

@ -1470,7 +1470,7 @@ impl<'a> BuiltinBuilder<'a> {
let ndarray_ptr_model =
PointerModel(StructModel(NpArray { sizet }));
let ndarray_ptr =
ndarray_ptr_model.check_llvm_value(arg.as_any_value_enum());
ndarray_ptr_model.review(arg.as_any_value_enum());
// Calculate len
// NOTE: Unsized object is asserted in IRRT