forked from M-Labs/nac3
WIP
This commit is contained in:
parent
0946bd86ea
commit
d90604b713
@ -126,6 +126,54 @@ namespace { namespace ndarray { namespace basic {
|
||||
|
||||
*dst_length = (SliceIndex) ndarray->shape[0];
|
||||
}
|
||||
|
||||
// Copy data from one ndarray to another *OF THE EXACT SAME* ndims, shape, and itemsize.
|
||||
template <typename SizeT>
|
||||
void copy_data(const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
|
||||
__builtin_assume(src_ndarray->ndims == dst_ndarray->ndims);
|
||||
__builtin_assume(src_ndarray->itemsize == dst_ndarray->itemsize);
|
||||
|
||||
for (SizeT i = 0; i < src_ndarray->size; i++) {
|
||||
auto src_element = ndarray::basic::get_nth_pelement(src_ndarray, i);
|
||||
auto dst_element = ndarray::basic::get_nth_pelement(dst_ndarray, i);
|
||||
ndarray::basic::set_pelement_value(dst_ndarray, dst_element, src_element);
|
||||
}
|
||||
}
|
||||
|
||||
// `copy_data()` with assertions to check ndims, shape, and itemsize between the two ndarrays.
|
||||
template <typename SizeT>
|
||||
void copy_data_checked(ErrorContext* errctx, const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
|
||||
// NOTE: Out of all error types, runtime error seems appropriate
|
||||
|
||||
// Check ndims
|
||||
if (src_ndarray->ndims != dst_ndarray->ndims) {
|
||||
errctx->set_error(
|
||||
errctx->error_ids->runtime_error,
|
||||
"IRRT copy_data_checked input arrays `ndims` are mismatched"
|
||||
);
|
||||
return; // Terminate
|
||||
}
|
||||
|
||||
// Check shape
|
||||
if (!arrays_match(src_ndarray->ndims, src_ndarray->shape, dst_ndarray->shape)) {
|
||||
errctx->set_error(
|
||||
errctx->error_ids->runtime_error,
|
||||
"IRRT copy_data_checked input arrays `shape` are mismatched"
|
||||
);
|
||||
return; // Terminate
|
||||
}
|
||||
|
||||
// Check itemsize
|
||||
if (src_ndarray->itemsize != dst_ndarray->itemsize) {
|
||||
errctx->set_error(
|
||||
errctx->error_ids->runtime_error,
|
||||
"IRRT copy_data_checked input arrays `itemsize` are mismatched"
|
||||
);
|
||||
return; // Terminate
|
||||
}
|
||||
|
||||
copy_data(src_ndarray, dst_ndarray);
|
||||
}
|
||||
} } }
|
||||
|
||||
extern "C" {
|
||||
|
@ -3,11 +3,12 @@
|
||||
namespace { namespace ndarray { namespace broadcast {
|
||||
namespace util {
|
||||
template <typename SizeT>
|
||||
bool can_broadcast_shape_to(
|
||||
void assert_broadcast_shape_to(
|
||||
ErrorContext* errctx,
|
||||
const SizeT target_ndims,
|
||||
const SizeT *target_shape,
|
||||
const SizeT* target_shape,
|
||||
const SizeT src_ndims,
|
||||
const SizeT *src_shape
|
||||
const SizeT* src_shape
|
||||
) {
|
||||
/*
|
||||
// See https://numpy.org/doc/stable/user/basics.broadcasting.html
|
||||
@ -20,23 +21,33 @@ namespace { namespace ndarray { namespace broadcast {
|
||||
```
|
||||
|
||||
Other interesting examples to consider:
|
||||
- `can_broadcast_shape_to([3], [1, 1, 1, 1, 3]) == true`
|
||||
- `can_broadcast_shape_to([3], [1, 1, 1, 1, 3]) ... ok`
|
||||
- `can_broadcast_shape_to([3], [3, 1]) == false`
|
||||
- `can_broadcast_shape_to([256, 256, 3], [256, 1, 3]) == true`
|
||||
- `can_broadcast_shape_to([256, 256, 3], [256, 1, 3]) ... ok`
|
||||
|
||||
In cases when the shapes contain zero(es):
|
||||
- `can_broadcast_shape_to([0], [1]) == true`
|
||||
- `can_broadcast_shape_to([0], [1]) ... ok`
|
||||
- `can_broadcast_shape_to([0], [2]) == false`
|
||||
- `can_broadcast_shape_to([0, 4, 0, 0], [1]) == true`
|
||||
- `can_broadcast_shape_to([0, 4, 0, 0], [1, 1, 1, 1]) == true`
|
||||
- `can_broadcast_shape_to([0, 4, 0, 0], [1, 4, 1, 1]) == true`
|
||||
- `can_broadcast_shape_to([0, 4, 0, 0], [1]) ... ok`
|
||||
- `can_broadcast_shape_to([0, 4, 0, 0], [1, 1, 1, 1]) ... ok`
|
||||
- `can_broadcast_shape_to([0, 4, 0, 0], [1, 4, 1, 1]) ... ok`
|
||||
- `can_broadcast_shape_to([4, 3], [0, 3]) == false`
|
||||
- `can_broadcast_shape_to([4, 3], [0, 0]) == false`
|
||||
*/
|
||||
|
||||
// This is essentially doing the following in Python:
|
||||
// `for target_dim, src_dim in itertools.zip_longest(target_shape[::-1], src_shape[::-1], fillvalue=1)`
|
||||
for (SizeT i = 0; i < max(target_ndims, src_ndims); i++) {
|
||||
// Target ndims must not be smaller than source ndims
|
||||
// e.g., `np.broadcast_to(np.zeros((1, 1, 1, 1)), (1, ))` is prohibited by numpy
|
||||
if (target_ndims < src_ndims) {
|
||||
// Error copied from python by doing the `np.broadcast_to(np.zeros((1, 1, 1, 1)), (1, ))`
|
||||
errctx->set_error(
|
||||
errctx->error_ids->value_error,
|
||||
"input operand has more dimensions than allowed by the axis remapping"
|
||||
);
|
||||
return; // Terminate
|
||||
}
|
||||
|
||||
// Implements the rules in https://numpy.org/doc/stable/user/basics.broadcasting.html
|
||||
for (SizeT i = 0; i < src_ndims; i++) {
|
||||
SizeT target_axis = target_ndims - i - 1;
|
||||
SizeT src_axis = src_ndims - i - 1;
|
||||
|
||||
@ -47,10 +58,18 @@ namespace { namespace ndarray { namespace broadcast {
|
||||
SizeT src_dim = src_dim_exists ? src_shape[src_axis] : 1;
|
||||
|
||||
bool ok = src_dim == 1 || target_dim == src_dim;
|
||||
if (!ok) return false;
|
||||
if (!ok) {
|
||||
// Error copied from python by doing `np.broadcast_to(np.zeros((3, 1)), (1, 1)),
|
||||
// but this is the true numpy error:
|
||||
// "ValueError: operands could not be broadcast together with remapped shapes [original->remapped]: (3,1) and requested shape (1,1)"
|
||||
// TODO: we cannot show more than 3 parameters!!
|
||||
errctx->set_error(
|
||||
errctx->error_ids->value_error,
|
||||
"operands could not be broadcast together with remapping shapes [original->remapped]"
|
||||
);
|
||||
return; // Terminate
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@ -79,18 +98,20 @@ namespace { namespace ndarray { namespace broadcast {
|
||||
// # This implementation will NOT support this assignment.
|
||||
// ```
|
||||
template <typename SizeT>
|
||||
void broadcast_to(NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
|
||||
void broadcast_to(ErrorContext* errctx, NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
|
||||
dst_ndarray->data = src_ndarray->data;
|
||||
dst_ndarray->itemsize = src_ndarray->itemsize;
|
||||
|
||||
// irrt_assert(
|
||||
// ndarray_util::can_broadcast_shape_to(
|
||||
// dst_ndarray->ndims,
|
||||
// dst_ndarray->shape,
|
||||
// src_ndarray->ndims,
|
||||
// src_ndarray->shape
|
||||
// )
|
||||
// );
|
||||
ndarray::broadcast::util::assert_broadcast_shape_to(
|
||||
errctx,
|
||||
dst_ndarray->ndims,
|
||||
dst_ndarray->shape,
|
||||
src_ndarray->ndims,
|
||||
src_ndarray->shape
|
||||
);
|
||||
if (errctx->has_error()) {
|
||||
return; // Propagate error
|
||||
}
|
||||
|
||||
SizeT stride_product = 1;
|
||||
for (SizeT i = 0; i < max(src_ndarray->ndims, dst_ndarray->ndims); i++) {
|
||||
|
@ -6,8 +6,6 @@
|
||||
#include <irrt/error_context.hpp>
|
||||
|
||||
namespace {
|
||||
typedef uint32_t NumNDSubscriptsType;
|
||||
|
||||
typedef uint8_t NDSubscriptType;
|
||||
|
||||
const NDSubscriptType INPUT_SUBSCRIPT_TYPE_INDEX = 0;
|
||||
@ -72,7 +70,7 @@ namespace { namespace ndarray { namespace subscript {
|
||||
// - `dst_ndarray->itemsize` does not have to be set, it will be set to `src_ndarray->itemsize`
|
||||
// - `dst_ndarray->shape` and `dst_ndarray.strides` can contain empty values
|
||||
template <typename SizeT>
|
||||
void subscript(ErrorContext* errctx, NumNDSubscriptsType num_subscripts, NDSubscript* subscripts, NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
|
||||
void subscript(ErrorContext* errctx, SliceIndex num_subscripts, NDSubscript* subscripts, NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
|
||||
// REFERENCE CODE (check out `_index_helper` in `__getitem__`):
|
||||
// https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L652
|
||||
|
||||
@ -84,7 +82,7 @@ namespace { namespace ndarray { namespace subscript {
|
||||
SizeT src_axis = 0;
|
||||
SizeT dst_axis = 0;
|
||||
|
||||
for (SizeT i = 0; i < num_subscripts; i++) {
|
||||
for (SliceIndex i = 0; i < num_subscripts; i++) {
|
||||
NDSubscript *ndsubscript = &subscripts[i];
|
||||
if (ndsubscript->type == INPUT_SUBSCRIPT_TYPE_INDEX) {
|
||||
// Handle when the ndsubscript is just a single (possibly negative) integer
|
||||
@ -161,11 +159,11 @@ extern "C" {
|
||||
ndarray::subscript::util::deduce_ndims_after_slicing(errctx, result, ndims, num_ndsubscripts, ndsubscripts);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_subscript(ErrorContext* errctx, NumNDSubscriptsType num_subscripts, NDSubscript* subscripts, NDArray<int32_t>* src_ndarray, NDArray<int32_t> *dst_ndarray) {
|
||||
void __nac3_ndarray_subscript(ErrorContext* errctx, SliceIndex num_subscripts, NDSubscript* subscripts, NDArray<int32_t>* src_ndarray, NDArray<int32_t> *dst_ndarray) {
|
||||
subscript(errctx, num_subscripts, subscripts, src_ndarray, dst_ndarray);
|
||||
}
|
||||
|
||||
void __nac3_ndarray_subscript64(ErrorContext* errctx, NumNDSubscriptsType num_subscripts, NDSubscript* subscripts, NDArray<int64_t>* src_ndarray, NDArray<int64_t> *dst_ndarray) {
|
||||
void __nac3_ndarray_subscript64(ErrorContext* errctx, SliceIndex num_subscripts, NDSubscript* subscripts, NDArray<int64_t>* src_ndarray, NDArray<int64_t> *dst_ndarray) {
|
||||
subscript(errctx, num_subscripts, subscripts, src_ndarray, dst_ndarray);
|
||||
}
|
||||
}
|
@ -11,6 +11,7 @@
|
||||
#include <test/test_core.hpp>
|
||||
#include <test/test_ndarray_basic.hpp>
|
||||
#include <test/test_ndarray_subscript.hpp>
|
||||
#include <test/test_ndarray_broadcast.hpp>
|
||||
#include <test/test_slice.hpp>
|
||||
|
||||
int main() {
|
||||
@ -19,5 +20,6 @@ int main() {
|
||||
test::slice::run();
|
||||
test::ndarray_basic::run();
|
||||
test::ndarray_subscript::run();
|
||||
test::ndarray_broadcast::run();
|
||||
return 0;
|
||||
}
|
72
nac3core/irrt/test/test_ndarray_broadcast.hpp
Normal file
72
nac3core/irrt/test/test_ndarray_broadcast.hpp
Normal file
@ -0,0 +1,72 @@
|
||||
#pragma once
|
||||
|
||||
#include <test/core.hpp>
|
||||
#include <irrt_everything.hpp>
|
||||
|
||||
namespace test { namespace ndarray_broadcast {
|
||||
void test_ndarray_broadcast_1() {
|
||||
/*
|
||||
```python
|
||||
array = np.array([[19.9, 29.9, 39.9, 49.9]], dtype=np.float64)
|
||||
>>> [[19.9 29.9 39.9 49.9]]
|
||||
|
||||
array = np.broadcast_to(array, (2, 3, 4))
|
||||
>>> [[[19.9 29.9 39.9 49.9]
|
||||
>>> [19.9 29.9 39.9 49.9]
|
||||
>>> [19.9 29.9 39.9 49.9]]
|
||||
>>> [[19.9 29.9 39.9 49.9]
|
||||
>>> [19.9 29.9 39.9 49.9]
|
||||
>>> [19.9 29.9 39.9 49.9]]]
|
||||
|
||||
assert array.strides == (0, 0, 8)
|
||||
# and then pick some values in `array` and check them...
|
||||
```
|
||||
*/
|
||||
BEGIN_TEST();
|
||||
|
||||
// Prepare src_ndarray
|
||||
double src_data[4] = { 19.9, 29.9, 39.9, 49.9 };
|
||||
const int32_t src_ndims = 2;
|
||||
int32_t src_shape[src_ndims] = {1, 4};
|
||||
int32_t src_strides[src_ndims] = {};
|
||||
NDArray<int32_t> src_ndarray = {
|
||||
.data = (uint8_t*) src_data,
|
||||
.itemsize = sizeof(double),
|
||||
.ndims = src_ndims,
|
||||
.shape = src_shape,
|
||||
.strides = src_strides
|
||||
};
|
||||
ndarray::basic::set_strides_by_shape(&src_ndarray);
|
||||
|
||||
// Prepare dst_ndarray
|
||||
const int32_t dst_ndims = 3;
|
||||
int32_t dst_shape[dst_ndims] = {2, 3, 4};
|
||||
int32_t dst_strides[dst_ndims] = {};
|
||||
NDArray<int32_t> dst_ndarray = {
|
||||
.ndims = dst_ndims,
|
||||
.shape = dst_shape,
|
||||
.strides = dst_strides
|
||||
};
|
||||
|
||||
// Broadcast
|
||||
ErrorContext errctx = create_testing_errctx();
|
||||
ndarray::broadcast::broadcast_to(&errctx, &src_ndarray, &dst_ndarray);
|
||||
assert_errctx_no_error(&errctx);
|
||||
|
||||
assert_arrays_match(dst_ndims, ((int32_t[]) { 0, 0, 8 }), dst_ndarray.strides);
|
||||
|
||||
assert_values_match(19.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 0, 0}))));
|
||||
assert_values_match(29.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 0, 1}))));
|
||||
assert_values_match(39.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 0, 2}))));
|
||||
assert_values_match(49.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 0, 3}))));
|
||||
assert_values_match(19.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 1, 0}))));
|
||||
assert_values_match(29.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 1, 1}))));
|
||||
assert_values_match(39.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 1, 2}))));
|
||||
assert_values_match(49.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 1, 3}))));
|
||||
assert_values_match(49.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {1, 2, 3}))));
|
||||
}
|
||||
|
||||
void run() {
|
||||
test_ndarray_broadcast_1();
|
||||
}
|
||||
}}
|
@ -2189,7 +2189,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||
Ok(match value_expr {
|
||||
None => None,
|
||||
Some(value_expr) => Some(
|
||||
slice_index_model.check_llvm_value(
|
||||
slice_index_model.review(
|
||||
generator
|
||||
.gen_expr(ctx, value_expr)?
|
||||
.unwrap()
|
||||
@ -2209,7 +2209,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||
// Anything else that is not a slice (might be illegal values),
|
||||
// For nac3core, this should be e.g., an int32 constant, an int32 variable, otherwise its an error
|
||||
|
||||
let index = slice_index_model.check_llvm_value(
|
||||
let index = slice_index_model.review(
|
||||
generator
|
||||
.gen_expr(ctx, subscript_expr)?
|
||||
.unwrap()
|
||||
@ -2931,7 +2931,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||
let ndarray_ptr_model = PointerModel(StructModel(NpArray { sizet }));
|
||||
|
||||
let v = v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?;
|
||||
ndarray_ptr_model.check_llvm_value(v.as_any_value_enum())
|
||||
ndarray_ptr_model.review(v.as_any_value_enum())
|
||||
} else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
@ -156,7 +156,7 @@ pub fn call_nac3_ndarray_subscript_deduce_ndims_after_slicing<'ctx, G: CodeGener
|
||||
pub fn call_nac3_ndarray_subscript<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
num_subscripts: FixedInt<'ctx, Int32>,
|
||||
num_subscripts: SliceIndex<'ctx>,
|
||||
subscripts: Pointer<'ctx, StructModel<NDSubscript>>,
|
||||
src_ndarray: Pointer<'ctx, StructModel<NpArray<'ctx>>>,
|
||||
dst_ndarray: Pointer<'ctx, StructModel<NpArray<'ctx>>>,
|
||||
@ -171,7 +171,7 @@ pub fn call_nac3_ndarray_subscript<'ctx, G: CodeGenerator + ?Sized>(
|
||||
&get_sized_dependent_function_name(sizet, "__nac3_ndarray_subscript"),
|
||||
)
|
||||
.arg("errctx", PointerModel(StructModel(ErrorContext)), errctx_ptr)
|
||||
.arg("num_subscripts", FixedIntModel(Int32), num_subscripts)
|
||||
.arg("num_subscripts", SliceIndexModel::default(), num_subscripts)
|
||||
.arg("subscripts", PointerModel(StructModel(NDSubscript)), subscripts)
|
||||
.arg("src_ndarray", PointerModel(StructModel(NpArray { sizet })), src_ndarray)
|
||||
.arg("dst_ndarray", PointerModel(StructModel(NpArray { sizet })), dst_ndarray)
|
||||
|
@ -61,7 +61,7 @@ impl<'ctx, 'a> FunctionBuilder<'ctx, 'a> {
|
||||
});
|
||||
|
||||
let ret = self.ctx.builder.build_call(function, ¶m_vals, name).unwrap();
|
||||
return_model.check_llvm_value(ret.as_any_value_enum())
|
||||
return_model.review(ret.as_any_value_enum())
|
||||
}
|
||||
|
||||
// TODO: Code duplication, but otherwise returning<S: Optic<'ctx>> cannot resolve S if return_optic = None
|
||||
|
@ -22,11 +22,18 @@ pub trait ModelValue<'ctx>: Clone + Copy {
|
||||
fn get_llvm_value(&self) -> BasicValueEnum<'ctx>;
|
||||
}
|
||||
|
||||
pub trait Model<'ctx>: Clone + Copy {
|
||||
// Should have been within [`Model<ctx>`],
|
||||
// but rust object safety requirements made it necessary to
|
||||
// split this interface out
|
||||
pub trait CanCheckLLVMType {
|
||||
fn check_llvm_type<'ctx>(&self, ctx: &'ctx Context) -> Result<(), String>;
|
||||
}
|
||||
|
||||
pub trait Model<'ctx>: Clone + Copy + CanCheckLLVMType + Sized {
|
||||
type Value: ModelValue<'ctx>;
|
||||
|
||||
fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx>;
|
||||
fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value;
|
||||
fn review(&self, value: AnyValueEnum<'ctx>) -> Self::Value;
|
||||
|
||||
fn alloca(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> Pointer<'ctx, Self> {
|
||||
Pointer {
|
||||
|
@ -7,7 +7,7 @@ use itertools::Itertools;
|
||||
|
||||
use crate::codegen::CodeGenContext;
|
||||
|
||||
use super::{Model, ModelValue, Pointer};
|
||||
use super::{core::CanCheckLLVMType, Model, ModelValue, Pointer};
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct Field<E> {
|
||||
@ -17,14 +17,12 @@ pub struct Field<E> {
|
||||
}
|
||||
|
||||
// Like [`Field<E>`] but element must be [`BasicTypeEnum<'ctx>`]
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
struct FieldLLVM<'ctx> {
|
||||
gep_index: u64,
|
||||
name: &'ctx str,
|
||||
llvm_type: BasicTypeEnum<'ctx>,
|
||||
llvm_type: Box<dyn CanCheckLLVMType>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FieldBuilder<'ctx> {
|
||||
pub ctx: &'ctx Context,
|
||||
gep_index_counter: u64,
|
||||
@ -57,6 +55,33 @@ impl<'ctx> FieldBuilder<'ctx> {
|
||||
}
|
||||
}
|
||||
|
||||
fn check_basic_types_match<'ctx, A, B>(expected: A, got: B) -> Result<(), String>
|
||||
where
|
||||
A: BasicType<'ctx>,
|
||||
B: BasicType<'ctx>,
|
||||
{
|
||||
let expected = expected.as_basic_type_enum();
|
||||
let got = got.as_basic_type_enum();
|
||||
|
||||
// Put those logic into here,
|
||||
// otherwise there is always a fallback reporting on any kind of mismatch
|
||||
match (expected, got) {
|
||||
(BasicTypeEnum::IntType(expected), BasicTypeEnum::IntType(got)) => {
|
||||
if expected.get_bit_width() != got.get_bit_width() {
|
||||
return Err(format!(
|
||||
"Expected IntType ({expected}-bit(s)), got IntType ({got}-bit(s))"
|
||||
));
|
||||
}
|
||||
}
|
||||
(expected, got) => {
|
||||
if expected != got {
|
||||
return Err(format!("Expected {expected}, got {got}"));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub trait IsStruct<'ctx>: Clone + Copy {
|
||||
type Fields;
|
||||
|
||||
@ -75,7 +100,12 @@ pub trait IsStruct<'ctx>: Clone + Copy {
|
||||
|
||||
let field_types =
|
||||
builder.fields.iter().map(|field_info| field_info.llvm_type).collect_vec();
|
||||
ctx.struct_type(&field_types, false)
|
||||
ctx.struct_type(&field_types, false).as_basic_type_enum().into_pointer_type().get_el
|
||||
}
|
||||
|
||||
fn check_struct_type(&self) {
|
||||
// Datatypes behind
|
||||
// check_basic_types_match
|
||||
}
|
||||
}
|
||||
|
||||
@ -94,6 +124,12 @@ impl<'ctx, S: IsStruct<'ctx>> ModelValue<'ctx> for Struct<'ctx, S> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx, S: IsStruct<'ctx>> CanCheckLLVMType<'ctx> for StructModel<S> {
|
||||
fn check_llvm_type<'ctx>(&self, ctx: &'ctx Context) -> Result<(), String> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx, S: IsStruct<'ctx>> Model<'ctx> for StructModel<S> {
|
||||
type Value = Struct<'ctx, S>; // TODO: enrich it
|
||||
|
||||
@ -101,7 +137,7 @@ impl<'ctx, S: IsStruct<'ctx>> Model<'ctx> for StructModel<S> {
|
||||
self.0.get_struct_type(ctx).as_basic_type_enum()
|
||||
}
|
||||
|
||||
fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value {
|
||||
fn review(&self, value: AnyValueEnum<'ctx>) -> Self::Value {
|
||||
// TODO: check structure
|
||||
Struct { structure: self.0, value: value.into_struct_value() }
|
||||
}
|
||||
|
@ -27,7 +27,7 @@ impl<'ctx> Model<'ctx> for IntModel<'ctx> {
|
||||
self.0.as_basic_type_enum()
|
||||
}
|
||||
|
||||
fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value {
|
||||
fn review(&self, value: AnyValueEnum<'ctx>) -> Self::Value {
|
||||
let int = value.into_int_value();
|
||||
assert_eq!(int.get_type().get_bit_width(), self.0.get_bit_width());
|
||||
Int(int)
|
||||
@ -130,7 +130,7 @@ impl<'ctx, T: IsFixedInt> Model<'ctx> for FixedIntModel<T> {
|
||||
T::get_int_type(ctx).as_basic_type_enum()
|
||||
}
|
||||
|
||||
fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value {
|
||||
fn review(&self, value: AnyValueEnum<'ctx>) -> Self::Value {
|
||||
let value = value.into_int_value();
|
||||
assert_eq!(value.get_type().get_bit_width(), T::get_bit_width());
|
||||
FixedInt { int: self.0, value }
|
||||
|
@ -31,7 +31,7 @@ impl<'ctx, E: Model<'ctx>> Pointer<'ctx, E> {
|
||||
|
||||
pub fn load(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> E::Value {
|
||||
let val = ctx.builder.build_load(self.value, name).unwrap();
|
||||
self.element.check_llvm_value(val.as_any_value_enum())
|
||||
self.element.review(val.as_any_value_enum())
|
||||
}
|
||||
|
||||
pub fn to_opaque(self) -> OpaquePointer<'ctx> {
|
||||
@ -66,7 +66,7 @@ impl<'ctx, E: Model<'ctx>> Model<'ctx> for PointerModel<E> {
|
||||
self.0.get_llvm_type(ctx).ptr_type(AddressSpace::default()).as_basic_type_enum()
|
||||
}
|
||||
|
||||
fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value {
|
||||
fn review(&self, value: AnyValueEnum<'ctx>) -> Self::Value {
|
||||
// TODO: Check get_element_type()? for LLVM 14 at least...
|
||||
Pointer { element: self.0, value: value.into_pointer_value() }
|
||||
}
|
||||
@ -92,7 +92,7 @@ impl<'ctx> Model<'ctx> for OpaquePointerModel {
|
||||
ctx.i8_type().ptr_type(AddressSpace::default()).as_basic_type_enum()
|
||||
}
|
||||
|
||||
fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value {
|
||||
fn review(&self, value: AnyValueEnum<'ctx>) -> Self::Value {
|
||||
let ptr = value.into_pointer_value();
|
||||
// TODO: remove this check once LLVM pointers do not have `get_element_type()`
|
||||
assert_eq!(ptr.get_type().get_element_type().into_int_type().get_bit_width(), 8);
|
||||
|
@ -189,10 +189,6 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
|
||||
v.data().ptr_offset(ctx, generator, &index, name)
|
||||
}
|
||||
|
||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||
todo!()
|
||||
}
|
||||
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
@ -207,89 +203,147 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
|
||||
target: &Expr<Option<Type>>,
|
||||
value: ValueEnum<'ctx>,
|
||||
) -> Result<(), String> {
|
||||
/*
|
||||
To handle assignment statements `target = value`, with
|
||||
special care taken for targets `gen_store_target` cannot handle, these are:
|
||||
- Case 1. target is a Tuple
|
||||
- e.g., `(x, y, z, w) = value`
|
||||
- Case 2. *Sliced* list assignment `list.__setitem__`
|
||||
- e.g., `my_list[1:3] = [100, 101]`, BUT NOT `my_list[0] = 99` (gen_store_target knows how to handle these),
|
||||
- Case 3. Indexed ndarray assignment `ndarray.__setitem__`
|
||||
- e.g., `my_ndarray[::-1, :] = 3`, `my_ndarray[:, 3::-1] = their_ndarray[10::2]`
|
||||
- NOTE: Technically speaking, if `target` is sliced in such as way that it is referencing a
|
||||
single element/scalar, we *could* implement gen_store_target for this special case;
|
||||
but it is much, *much* simpler to generalize all indexed ndarray assignment without
|
||||
special handling on that edgecase.
|
||||
- Otherwise, use `gen_store_target`
|
||||
*/
|
||||
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
match &target.node {
|
||||
ExprKind::Tuple { elts, .. } => {
|
||||
let BasicValueEnum::StructValue(v) =
|
||||
value.to_basic_value_enum(ctx, generator, target.custom.unwrap())?
|
||||
else {
|
||||
unreachable!()
|
||||
};
|
||||
if let ExprKind::Tuple { elts, .. } = &target.node {
|
||||
// Handle Case 1. target is a Tuple
|
||||
let BasicValueEnum::StructValue(v) =
|
||||
value.to_basic_value_enum(ctx, generator, target.custom.unwrap())?
|
||||
else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
for (i, elt) in elts.iter().enumerate() {
|
||||
let v = ctx
|
||||
.builder
|
||||
.build_extract_value(v, u32::try_from(i).unwrap(), "struct_elem")
|
||||
.unwrap();
|
||||
generator.gen_assign(ctx, elt, v.into())?;
|
||||
for (i, elt) in elts.iter().enumerate() {
|
||||
let v = ctx
|
||||
.builder
|
||||
.build_extract_value(v, u32::try_from(i).unwrap(), "struct_elem")
|
||||
.unwrap();
|
||||
generator.gen_assign(ctx, elt, v.into())?;
|
||||
}
|
||||
|
||||
return Ok(()); // Terminate
|
||||
}
|
||||
|
||||
// Else, try checking if it's Case 2 or 3, and they *ONLY*
|
||||
// happen if `target.node` is a `ExprKind::Subscript`, so do a special check
|
||||
if let ExprKind::Subscript { value: target_without_slice, slice, .. } = &target.node {
|
||||
// Get the type of target
|
||||
let target_ty = target.custom.unwrap();
|
||||
let target_ty_enum = &*ctx.unifier.get_ty(target_ty);
|
||||
|
||||
// Pattern match on this pair.
|
||||
// This is done like this because of Case 2 - slice.node has to be in a specific pattern
|
||||
match (target_ty_enum, &slice.node) {
|
||||
(TypeEnum::TObj { obj_id, .. }, ExprKind::Slice { lower, upper, step })
|
||||
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
|
||||
{
|
||||
// Case 2. *Sliced* list assignment
|
||||
|
||||
let ls = generator
|
||||
.gen_expr(ctx, target_without_slice)?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, target_without_slice.custom.unwrap())?
|
||||
.into_pointer_value();
|
||||
let ls = ListValue::from_ptr_val(ls, llvm_usize, None);
|
||||
let Some((start, end, step)) = handle_slice_indices(
|
||||
lower,
|
||||
upper,
|
||||
step,
|
||||
ctx,
|
||||
generator,
|
||||
ls.load_size(ctx, None),
|
||||
)?
|
||||
else {
|
||||
return Ok(());
|
||||
};
|
||||
let value = value
|
||||
.to_basic_value_enum(ctx, generator, target.custom.unwrap())?
|
||||
.into_pointer_value();
|
||||
let value = ListValue::from_ptr_val(value, llvm_usize, None);
|
||||
let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) {
|
||||
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
|
||||
*params.iter().next().unwrap().1
|
||||
}
|
||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, target.custom.unwrap()).0
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let ty = ctx.get_llvm_type(generator, ty);
|
||||
let Some(src_ind) = handle_slice_indices(
|
||||
&None,
|
||||
&None,
|
||||
&None,
|
||||
ctx,
|
||||
generator,
|
||||
value.load_size(ctx, None),
|
||||
)?
|
||||
else {
|
||||
return Ok(());
|
||||
};
|
||||
list_slice_assignment(generator, ctx, ty, ls, (start, end, step), value, src_ind);
|
||||
|
||||
return Ok(()); // Terminate
|
||||
}
|
||||
(TypeEnum::TObj { obj_id, .. }, _)
|
||||
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
|
||||
{
|
||||
// Case 3. Indexed ndarray assignment
|
||||
|
||||
let target = generator.gen_expr(ctx, target)?.unwrap().to_basic_value_enum(
|
||||
ctx,
|
||||
generator,
|
||||
target.custom.unwrap(),
|
||||
);
|
||||
|
||||
// let value = value.to_basic_value_enum(ctx, generator, value);
|
||||
|
||||
todo!();
|
||||
|
||||
return Ok(()); // Terminate
|
||||
}
|
||||
_ => {
|
||||
// Fallthrough
|
||||
}
|
||||
}
|
||||
ExprKind::Subscript { value: ls, slice, .. }
|
||||
if matches!(&slice.node, ExprKind::Slice { .. }) =>
|
||||
{
|
||||
let ExprKind::Slice { lower, upper, step } = &slice.node else { unreachable!() };
|
||||
}
|
||||
|
||||
let ls = generator
|
||||
.gen_expr(ctx, ls)?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, ls.custom.unwrap())?
|
||||
.into_pointer_value();
|
||||
let ls = ListValue::from_ptr_val(ls, llvm_usize, None);
|
||||
let Some((start, end, step)) =
|
||||
handle_slice_indices(lower, upper, step, ctx, generator, ls.load_size(ctx, None))?
|
||||
else {
|
||||
return Ok(());
|
||||
};
|
||||
let value = value
|
||||
.to_basic_value_enum(ctx, generator, target.custom.unwrap())?
|
||||
.into_pointer_value();
|
||||
let value = ListValue::from_ptr_val(value, llvm_usize, None);
|
||||
let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) {
|
||||
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
|
||||
*params.iter().next().unwrap().1
|
||||
}
|
||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, target.custom.unwrap()).0
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let ty = ctx.get_llvm_type(generator, ty);
|
||||
let Some(src_ind) = handle_slice_indices(
|
||||
&None,
|
||||
&None,
|
||||
&None,
|
||||
ctx,
|
||||
generator,
|
||||
value.load_size(ctx, None),
|
||||
)?
|
||||
else {
|
||||
return Ok(());
|
||||
};
|
||||
list_slice_assignment(generator, ctx, ty, ls, (start, end, step), value, src_ind);
|
||||
}
|
||||
_ => {
|
||||
let name = if let ExprKind::Name { id, .. } = &target.node {
|
||||
format!("{id}.addr")
|
||||
} else {
|
||||
String::from("target.addr")
|
||||
};
|
||||
let Some(ptr) = generator.gen_store_target(ctx, target, Some(name.as_str()))? else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
if let ExprKind::Name { id, .. } = &target.node {
|
||||
let (_, static_value, counter) = ctx.var_assignment.get_mut(id).unwrap();
|
||||
*counter += 1;
|
||||
if let ValueEnum::Static(s) = &value {
|
||||
*static_value = Some(s.clone());
|
||||
}
|
||||
}
|
||||
let val = value.to_basic_value_enum(ctx, generator, target.custom.unwrap())?;
|
||||
ctx.builder.build_store(ptr, val).unwrap();
|
||||
}
|
||||
// None of the cases match. We should actually use `gen_store_target`.
|
||||
let name = if let ExprKind::Name { id, .. } = &target.node {
|
||||
format!("{id}.addr")
|
||||
} else {
|
||||
String::from("target.addr")
|
||||
};
|
||||
let Some(ptr) = generator.gen_store_target(ctx, target, Some(name.as_str()))? else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
if let ExprKind::Name { id, .. } = &target.node {
|
||||
let (_, static_value, counter) = ctx.var_assignment.get_mut(id).unwrap();
|
||||
*counter += 1;
|
||||
if let ValueEnum::Static(s) = &value {
|
||||
*static_value = Some(s.clone());
|
||||
}
|
||||
}
|
||||
let val = value.to_basic_value_enum(ctx, generator, target.custom.unwrap())?;
|
||||
ctx.builder.build_store(ptr, val).unwrap();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -1470,7 +1470,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
let ndarray_ptr_model =
|
||||
PointerModel(StructModel(NpArray { sizet }));
|
||||
let ndarray_ptr =
|
||||
ndarray_ptr_model.check_llvm_value(arg.as_any_value_enum());
|
||||
ndarray_ptr_model.review(arg.as_any_value_enum());
|
||||
|
||||
// Calculate len
|
||||
// NOTE: Unsized object is asserted in IRRT
|
||||
|
Loading…
Reference in New Issue
Block a user