forked from M-Labs/nac3
WIP
This commit is contained in:
parent
0946bd86ea
commit
d90604b713
|
@ -126,6 +126,54 @@ namespace { namespace ndarray { namespace basic {
|
||||||
|
|
||||||
*dst_length = (SliceIndex) ndarray->shape[0];
|
*dst_length = (SliceIndex) ndarray->shape[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Copy data from one ndarray to another *OF THE EXACT SAME* ndims, shape, and itemsize.
|
||||||
|
template <typename SizeT>
|
||||||
|
void copy_data(const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
|
||||||
|
__builtin_assume(src_ndarray->ndims == dst_ndarray->ndims);
|
||||||
|
__builtin_assume(src_ndarray->itemsize == dst_ndarray->itemsize);
|
||||||
|
|
||||||
|
for (SizeT i = 0; i < src_ndarray->size; i++) {
|
||||||
|
auto src_element = ndarray::basic::get_nth_pelement(src_ndarray, i);
|
||||||
|
auto dst_element = ndarray::basic::get_nth_pelement(dst_ndarray, i);
|
||||||
|
ndarray::basic::set_pelement_value(dst_ndarray, dst_element, src_element);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// `copy_data()` with assertions to check ndims, shape, and itemsize between the two ndarrays.
|
||||||
|
template <typename SizeT>
|
||||||
|
void copy_data_checked(ErrorContext* errctx, const NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
|
||||||
|
// NOTE: Out of all error types, runtime error seems appropriate
|
||||||
|
|
||||||
|
// Check ndims
|
||||||
|
if (src_ndarray->ndims != dst_ndarray->ndims) {
|
||||||
|
errctx->set_error(
|
||||||
|
errctx->error_ids->runtime_error,
|
||||||
|
"IRRT copy_data_checked input arrays `ndims` are mismatched"
|
||||||
|
);
|
||||||
|
return; // Terminate
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check shape
|
||||||
|
if (!arrays_match(src_ndarray->ndims, src_ndarray->shape, dst_ndarray->shape)) {
|
||||||
|
errctx->set_error(
|
||||||
|
errctx->error_ids->runtime_error,
|
||||||
|
"IRRT copy_data_checked input arrays `shape` are mismatched"
|
||||||
|
);
|
||||||
|
return; // Terminate
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check itemsize
|
||||||
|
if (src_ndarray->itemsize != dst_ndarray->itemsize) {
|
||||||
|
errctx->set_error(
|
||||||
|
errctx->error_ids->runtime_error,
|
||||||
|
"IRRT copy_data_checked input arrays `itemsize` are mismatched"
|
||||||
|
);
|
||||||
|
return; // Terminate
|
||||||
|
}
|
||||||
|
|
||||||
|
copy_data(src_ndarray, dst_ndarray);
|
||||||
|
}
|
||||||
} } }
|
} } }
|
||||||
|
|
||||||
extern "C" {
|
extern "C" {
|
||||||
|
|
|
@ -3,7 +3,8 @@
|
||||||
namespace { namespace ndarray { namespace broadcast {
|
namespace { namespace ndarray { namespace broadcast {
|
||||||
namespace util {
|
namespace util {
|
||||||
template <typename SizeT>
|
template <typename SizeT>
|
||||||
bool can_broadcast_shape_to(
|
void assert_broadcast_shape_to(
|
||||||
|
ErrorContext* errctx,
|
||||||
const SizeT target_ndims,
|
const SizeT target_ndims,
|
||||||
const SizeT* target_shape,
|
const SizeT* target_shape,
|
||||||
const SizeT src_ndims,
|
const SizeT src_ndims,
|
||||||
|
@ -20,23 +21,33 @@ namespace { namespace ndarray { namespace broadcast {
|
||||||
```
|
```
|
||||||
|
|
||||||
Other interesting examples to consider:
|
Other interesting examples to consider:
|
||||||
- `can_broadcast_shape_to([3], [1, 1, 1, 1, 3]) == true`
|
- `can_broadcast_shape_to([3], [1, 1, 1, 1, 3]) ... ok`
|
||||||
- `can_broadcast_shape_to([3], [3, 1]) == false`
|
- `can_broadcast_shape_to([3], [3, 1]) == false`
|
||||||
- `can_broadcast_shape_to([256, 256, 3], [256, 1, 3]) == true`
|
- `can_broadcast_shape_to([256, 256, 3], [256, 1, 3]) ... ok`
|
||||||
|
|
||||||
In cases when the shapes contain zero(es):
|
In cases when the shapes contain zero(es):
|
||||||
- `can_broadcast_shape_to([0], [1]) == true`
|
- `can_broadcast_shape_to([0], [1]) ... ok`
|
||||||
- `can_broadcast_shape_to([0], [2]) == false`
|
- `can_broadcast_shape_to([0], [2]) == false`
|
||||||
- `can_broadcast_shape_to([0, 4, 0, 0], [1]) == true`
|
- `can_broadcast_shape_to([0, 4, 0, 0], [1]) ... ok`
|
||||||
- `can_broadcast_shape_to([0, 4, 0, 0], [1, 1, 1, 1]) == true`
|
- `can_broadcast_shape_to([0, 4, 0, 0], [1, 1, 1, 1]) ... ok`
|
||||||
- `can_broadcast_shape_to([0, 4, 0, 0], [1, 4, 1, 1]) == true`
|
- `can_broadcast_shape_to([0, 4, 0, 0], [1, 4, 1, 1]) ... ok`
|
||||||
- `can_broadcast_shape_to([4, 3], [0, 3]) == false`
|
- `can_broadcast_shape_to([4, 3], [0, 3]) == false`
|
||||||
- `can_broadcast_shape_to([4, 3], [0, 0]) == false`
|
- `can_broadcast_shape_to([4, 3], [0, 0]) == false`
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// This is essentially doing the following in Python:
|
// Target ndims must not be smaller than source ndims
|
||||||
// `for target_dim, src_dim in itertools.zip_longest(target_shape[::-1], src_shape[::-1], fillvalue=1)`
|
// e.g., `np.broadcast_to(np.zeros((1, 1, 1, 1)), (1, ))` is prohibited by numpy
|
||||||
for (SizeT i = 0; i < max(target_ndims, src_ndims); i++) {
|
if (target_ndims < src_ndims) {
|
||||||
|
// Error copied from python by doing the `np.broadcast_to(np.zeros((1, 1, 1, 1)), (1, ))`
|
||||||
|
errctx->set_error(
|
||||||
|
errctx->error_ids->value_error,
|
||||||
|
"input operand has more dimensions than allowed by the axis remapping"
|
||||||
|
);
|
||||||
|
return; // Terminate
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implements the rules in https://numpy.org/doc/stable/user/basics.broadcasting.html
|
||||||
|
for (SizeT i = 0; i < src_ndims; i++) {
|
||||||
SizeT target_axis = target_ndims - i - 1;
|
SizeT target_axis = target_ndims - i - 1;
|
||||||
SizeT src_axis = src_ndims - i - 1;
|
SizeT src_axis = src_ndims - i - 1;
|
||||||
|
|
||||||
|
@ -47,10 +58,18 @@ namespace { namespace ndarray { namespace broadcast {
|
||||||
SizeT src_dim = src_dim_exists ? src_shape[src_axis] : 1;
|
SizeT src_dim = src_dim_exists ? src_shape[src_axis] : 1;
|
||||||
|
|
||||||
bool ok = src_dim == 1 || target_dim == src_dim;
|
bool ok = src_dim == 1 || target_dim == src_dim;
|
||||||
if (!ok) return false;
|
if (!ok) {
|
||||||
|
// Error copied from python by doing `np.broadcast_to(np.zeros((3, 1)), (1, 1)),
|
||||||
|
// but this is the true numpy error:
|
||||||
|
// "ValueError: operands could not be broadcast together with remapped shapes [original->remapped]: (3,1) and requested shape (1,1)"
|
||||||
|
// TODO: we cannot show more than 3 parameters!!
|
||||||
|
errctx->set_error(
|
||||||
|
errctx->error_ids->value_error,
|
||||||
|
"operands could not be broadcast together with remapping shapes [original->remapped]"
|
||||||
|
);
|
||||||
|
return; // Terminate
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -79,18 +98,20 @@ namespace { namespace ndarray { namespace broadcast {
|
||||||
// # This implementation will NOT support this assignment.
|
// # This implementation will NOT support this assignment.
|
||||||
// ```
|
// ```
|
||||||
template <typename SizeT>
|
template <typename SizeT>
|
||||||
void broadcast_to(NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
|
void broadcast_to(ErrorContext* errctx, NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
|
||||||
dst_ndarray->data = src_ndarray->data;
|
dst_ndarray->data = src_ndarray->data;
|
||||||
dst_ndarray->itemsize = src_ndarray->itemsize;
|
dst_ndarray->itemsize = src_ndarray->itemsize;
|
||||||
|
|
||||||
// irrt_assert(
|
ndarray::broadcast::util::assert_broadcast_shape_to(
|
||||||
// ndarray_util::can_broadcast_shape_to(
|
errctx,
|
||||||
// dst_ndarray->ndims,
|
dst_ndarray->ndims,
|
||||||
// dst_ndarray->shape,
|
dst_ndarray->shape,
|
||||||
// src_ndarray->ndims,
|
src_ndarray->ndims,
|
||||||
// src_ndarray->shape
|
src_ndarray->shape
|
||||||
// )
|
);
|
||||||
// );
|
if (errctx->has_error()) {
|
||||||
|
return; // Propagate error
|
||||||
|
}
|
||||||
|
|
||||||
SizeT stride_product = 1;
|
SizeT stride_product = 1;
|
||||||
for (SizeT i = 0; i < max(src_ndarray->ndims, dst_ndarray->ndims); i++) {
|
for (SizeT i = 0; i < max(src_ndarray->ndims, dst_ndarray->ndims); i++) {
|
||||||
|
|
|
@ -6,8 +6,6 @@
|
||||||
#include <irrt/error_context.hpp>
|
#include <irrt/error_context.hpp>
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
typedef uint32_t NumNDSubscriptsType;
|
|
||||||
|
|
||||||
typedef uint8_t NDSubscriptType;
|
typedef uint8_t NDSubscriptType;
|
||||||
|
|
||||||
const NDSubscriptType INPUT_SUBSCRIPT_TYPE_INDEX = 0;
|
const NDSubscriptType INPUT_SUBSCRIPT_TYPE_INDEX = 0;
|
||||||
|
@ -72,7 +70,7 @@ namespace { namespace ndarray { namespace subscript {
|
||||||
// - `dst_ndarray->itemsize` does not have to be set, it will be set to `src_ndarray->itemsize`
|
// - `dst_ndarray->itemsize` does not have to be set, it will be set to `src_ndarray->itemsize`
|
||||||
// - `dst_ndarray->shape` and `dst_ndarray.strides` can contain empty values
|
// - `dst_ndarray->shape` and `dst_ndarray.strides` can contain empty values
|
||||||
template <typename SizeT>
|
template <typename SizeT>
|
||||||
void subscript(ErrorContext* errctx, NumNDSubscriptsType num_subscripts, NDSubscript* subscripts, NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
|
void subscript(ErrorContext* errctx, SliceIndex num_subscripts, NDSubscript* subscripts, NDArray<SizeT>* src_ndarray, NDArray<SizeT>* dst_ndarray) {
|
||||||
// REFERENCE CODE (check out `_index_helper` in `__getitem__`):
|
// REFERENCE CODE (check out `_index_helper` in `__getitem__`):
|
||||||
// https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L652
|
// https://github.com/wadetb/tinynumpy/blob/0d23d22e07062ffab2afa287374c7b366eebdda1/tinynumpy/tinynumpy.py#L652
|
||||||
|
|
||||||
|
@ -84,7 +82,7 @@ namespace { namespace ndarray { namespace subscript {
|
||||||
SizeT src_axis = 0;
|
SizeT src_axis = 0;
|
||||||
SizeT dst_axis = 0;
|
SizeT dst_axis = 0;
|
||||||
|
|
||||||
for (SizeT i = 0; i < num_subscripts; i++) {
|
for (SliceIndex i = 0; i < num_subscripts; i++) {
|
||||||
NDSubscript *ndsubscript = &subscripts[i];
|
NDSubscript *ndsubscript = &subscripts[i];
|
||||||
if (ndsubscript->type == INPUT_SUBSCRIPT_TYPE_INDEX) {
|
if (ndsubscript->type == INPUT_SUBSCRIPT_TYPE_INDEX) {
|
||||||
// Handle when the ndsubscript is just a single (possibly negative) integer
|
// Handle when the ndsubscript is just a single (possibly negative) integer
|
||||||
|
@ -161,11 +159,11 @@ extern "C" {
|
||||||
ndarray::subscript::util::deduce_ndims_after_slicing(errctx, result, ndims, num_ndsubscripts, ndsubscripts);
|
ndarray::subscript::util::deduce_ndims_after_slicing(errctx, result, ndims, num_ndsubscripts, ndsubscripts);
|
||||||
}
|
}
|
||||||
|
|
||||||
void __nac3_ndarray_subscript(ErrorContext* errctx, NumNDSubscriptsType num_subscripts, NDSubscript* subscripts, NDArray<int32_t>* src_ndarray, NDArray<int32_t> *dst_ndarray) {
|
void __nac3_ndarray_subscript(ErrorContext* errctx, SliceIndex num_subscripts, NDSubscript* subscripts, NDArray<int32_t>* src_ndarray, NDArray<int32_t> *dst_ndarray) {
|
||||||
subscript(errctx, num_subscripts, subscripts, src_ndarray, dst_ndarray);
|
subscript(errctx, num_subscripts, subscripts, src_ndarray, dst_ndarray);
|
||||||
}
|
}
|
||||||
|
|
||||||
void __nac3_ndarray_subscript64(ErrorContext* errctx, NumNDSubscriptsType num_subscripts, NDSubscript* subscripts, NDArray<int64_t>* src_ndarray, NDArray<int64_t> *dst_ndarray) {
|
void __nac3_ndarray_subscript64(ErrorContext* errctx, SliceIndex num_subscripts, NDSubscript* subscripts, NDArray<int64_t>* src_ndarray, NDArray<int64_t> *dst_ndarray) {
|
||||||
subscript(errctx, num_subscripts, subscripts, src_ndarray, dst_ndarray);
|
subscript(errctx, num_subscripts, subscripts, src_ndarray, dst_ndarray);
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -11,6 +11,7 @@
|
||||||
#include <test/test_core.hpp>
|
#include <test/test_core.hpp>
|
||||||
#include <test/test_ndarray_basic.hpp>
|
#include <test/test_ndarray_basic.hpp>
|
||||||
#include <test/test_ndarray_subscript.hpp>
|
#include <test/test_ndarray_subscript.hpp>
|
||||||
|
#include <test/test_ndarray_broadcast.hpp>
|
||||||
#include <test/test_slice.hpp>
|
#include <test/test_slice.hpp>
|
||||||
|
|
||||||
int main() {
|
int main() {
|
||||||
|
@ -19,5 +20,6 @@ int main() {
|
||||||
test::slice::run();
|
test::slice::run();
|
||||||
test::ndarray_basic::run();
|
test::ndarray_basic::run();
|
||||||
test::ndarray_subscript::run();
|
test::ndarray_subscript::run();
|
||||||
|
test::ndarray_broadcast::run();
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
|
@ -0,0 +1,72 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <test/core.hpp>
|
||||||
|
#include <irrt_everything.hpp>
|
||||||
|
|
||||||
|
namespace test { namespace ndarray_broadcast {
|
||||||
|
void test_ndarray_broadcast_1() {
|
||||||
|
/*
|
||||||
|
```python
|
||||||
|
array = np.array([[19.9, 29.9, 39.9, 49.9]], dtype=np.float64)
|
||||||
|
>>> [[19.9 29.9 39.9 49.9]]
|
||||||
|
|
||||||
|
array = np.broadcast_to(array, (2, 3, 4))
|
||||||
|
>>> [[[19.9 29.9 39.9 49.9]
|
||||||
|
>>> [19.9 29.9 39.9 49.9]
|
||||||
|
>>> [19.9 29.9 39.9 49.9]]
|
||||||
|
>>> [[19.9 29.9 39.9 49.9]
|
||||||
|
>>> [19.9 29.9 39.9 49.9]
|
||||||
|
>>> [19.9 29.9 39.9 49.9]]]
|
||||||
|
|
||||||
|
assert array.strides == (0, 0, 8)
|
||||||
|
# and then pick some values in `array` and check them...
|
||||||
|
```
|
||||||
|
*/
|
||||||
|
BEGIN_TEST();
|
||||||
|
|
||||||
|
// Prepare src_ndarray
|
||||||
|
double src_data[4] = { 19.9, 29.9, 39.9, 49.9 };
|
||||||
|
const int32_t src_ndims = 2;
|
||||||
|
int32_t src_shape[src_ndims] = {1, 4};
|
||||||
|
int32_t src_strides[src_ndims] = {};
|
||||||
|
NDArray<int32_t> src_ndarray = {
|
||||||
|
.data = (uint8_t*) src_data,
|
||||||
|
.itemsize = sizeof(double),
|
||||||
|
.ndims = src_ndims,
|
||||||
|
.shape = src_shape,
|
||||||
|
.strides = src_strides
|
||||||
|
};
|
||||||
|
ndarray::basic::set_strides_by_shape(&src_ndarray);
|
||||||
|
|
||||||
|
// Prepare dst_ndarray
|
||||||
|
const int32_t dst_ndims = 3;
|
||||||
|
int32_t dst_shape[dst_ndims] = {2, 3, 4};
|
||||||
|
int32_t dst_strides[dst_ndims] = {};
|
||||||
|
NDArray<int32_t> dst_ndarray = {
|
||||||
|
.ndims = dst_ndims,
|
||||||
|
.shape = dst_shape,
|
||||||
|
.strides = dst_strides
|
||||||
|
};
|
||||||
|
|
||||||
|
// Broadcast
|
||||||
|
ErrorContext errctx = create_testing_errctx();
|
||||||
|
ndarray::broadcast::broadcast_to(&errctx, &src_ndarray, &dst_ndarray);
|
||||||
|
assert_errctx_no_error(&errctx);
|
||||||
|
|
||||||
|
assert_arrays_match(dst_ndims, ((int32_t[]) { 0, 0, 8 }), dst_ndarray.strides);
|
||||||
|
|
||||||
|
assert_values_match(19.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 0, 0}))));
|
||||||
|
assert_values_match(29.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 0, 1}))));
|
||||||
|
assert_values_match(39.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 0, 2}))));
|
||||||
|
assert_values_match(49.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 0, 3}))));
|
||||||
|
assert_values_match(19.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 1, 0}))));
|
||||||
|
assert_values_match(29.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 1, 1}))));
|
||||||
|
assert_values_match(39.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 1, 2}))));
|
||||||
|
assert_values_match(49.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {0, 1, 3}))));
|
||||||
|
assert_values_match(49.9, *((double*) ndarray::basic::get_pelement_by_indices(&dst_ndarray, ((int32_t[]) {1, 2, 3}))));
|
||||||
|
}
|
||||||
|
|
||||||
|
void run() {
|
||||||
|
test_ndarray_broadcast_1();
|
||||||
|
}
|
||||||
|
}}
|
|
@ -2189,7 +2189,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||||
Ok(match value_expr {
|
Ok(match value_expr {
|
||||||
None => None,
|
None => None,
|
||||||
Some(value_expr) => Some(
|
Some(value_expr) => Some(
|
||||||
slice_index_model.check_llvm_value(
|
slice_index_model.review(
|
||||||
generator
|
generator
|
||||||
.gen_expr(ctx, value_expr)?
|
.gen_expr(ctx, value_expr)?
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
@ -2209,7 +2209,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||||
// Anything else that is not a slice (might be illegal values),
|
// Anything else that is not a slice (might be illegal values),
|
||||||
// For nac3core, this should be e.g., an int32 constant, an int32 variable, otherwise its an error
|
// For nac3core, this should be e.g., an int32 constant, an int32 variable, otherwise its an error
|
||||||
|
|
||||||
let index = slice_index_model.check_llvm_value(
|
let index = slice_index_model.review(
|
||||||
generator
|
generator
|
||||||
.gen_expr(ctx, subscript_expr)?
|
.gen_expr(ctx, subscript_expr)?
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
@ -2931,7 +2931,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||||
let ndarray_ptr_model = PointerModel(StructModel(NpArray { sizet }));
|
let ndarray_ptr_model = PointerModel(StructModel(NpArray { sizet }));
|
||||||
|
|
||||||
let v = v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?;
|
let v = v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?;
|
||||||
ndarray_ptr_model.check_llvm_value(v.as_any_value_enum())
|
ndarray_ptr_model.review(v.as_any_value_enum())
|
||||||
} else {
|
} else {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
};
|
};
|
||||||
|
|
|
@ -156,7 +156,7 @@ pub fn call_nac3_ndarray_subscript_deduce_ndims_after_slicing<'ctx, G: CodeGener
|
||||||
pub fn call_nac3_ndarray_subscript<'ctx, G: CodeGenerator + ?Sized>(
|
pub fn call_nac3_ndarray_subscript<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
num_subscripts: FixedInt<'ctx, Int32>,
|
num_subscripts: SliceIndex<'ctx>,
|
||||||
subscripts: Pointer<'ctx, StructModel<NDSubscript>>,
|
subscripts: Pointer<'ctx, StructModel<NDSubscript>>,
|
||||||
src_ndarray: Pointer<'ctx, StructModel<NpArray<'ctx>>>,
|
src_ndarray: Pointer<'ctx, StructModel<NpArray<'ctx>>>,
|
||||||
dst_ndarray: Pointer<'ctx, StructModel<NpArray<'ctx>>>,
|
dst_ndarray: Pointer<'ctx, StructModel<NpArray<'ctx>>>,
|
||||||
|
@ -171,7 +171,7 @@ pub fn call_nac3_ndarray_subscript<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
&get_sized_dependent_function_name(sizet, "__nac3_ndarray_subscript"),
|
&get_sized_dependent_function_name(sizet, "__nac3_ndarray_subscript"),
|
||||||
)
|
)
|
||||||
.arg("errctx", PointerModel(StructModel(ErrorContext)), errctx_ptr)
|
.arg("errctx", PointerModel(StructModel(ErrorContext)), errctx_ptr)
|
||||||
.arg("num_subscripts", FixedIntModel(Int32), num_subscripts)
|
.arg("num_subscripts", SliceIndexModel::default(), num_subscripts)
|
||||||
.arg("subscripts", PointerModel(StructModel(NDSubscript)), subscripts)
|
.arg("subscripts", PointerModel(StructModel(NDSubscript)), subscripts)
|
||||||
.arg("src_ndarray", PointerModel(StructModel(NpArray { sizet })), src_ndarray)
|
.arg("src_ndarray", PointerModel(StructModel(NpArray { sizet })), src_ndarray)
|
||||||
.arg("dst_ndarray", PointerModel(StructModel(NpArray { sizet })), dst_ndarray)
|
.arg("dst_ndarray", PointerModel(StructModel(NpArray { sizet })), dst_ndarray)
|
||||||
|
|
|
@ -61,7 +61,7 @@ impl<'ctx, 'a> FunctionBuilder<'ctx, 'a> {
|
||||||
});
|
});
|
||||||
|
|
||||||
let ret = self.ctx.builder.build_call(function, ¶m_vals, name).unwrap();
|
let ret = self.ctx.builder.build_call(function, ¶m_vals, name).unwrap();
|
||||||
return_model.check_llvm_value(ret.as_any_value_enum())
|
return_model.review(ret.as_any_value_enum())
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Code duplication, but otherwise returning<S: Optic<'ctx>> cannot resolve S if return_optic = None
|
// TODO: Code duplication, but otherwise returning<S: Optic<'ctx>> cannot resolve S if return_optic = None
|
||||||
|
|
|
@ -22,11 +22,18 @@ pub trait ModelValue<'ctx>: Clone + Copy {
|
||||||
fn get_llvm_value(&self) -> BasicValueEnum<'ctx>;
|
fn get_llvm_value(&self) -> BasicValueEnum<'ctx>;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait Model<'ctx>: Clone + Copy {
|
// Should have been within [`Model<ctx>`],
|
||||||
|
// but rust object safety requirements made it necessary to
|
||||||
|
// split this interface out
|
||||||
|
pub trait CanCheckLLVMType {
|
||||||
|
fn check_llvm_type<'ctx>(&self, ctx: &'ctx Context) -> Result<(), String>;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait Model<'ctx>: Clone + Copy + CanCheckLLVMType + Sized {
|
||||||
type Value: ModelValue<'ctx>;
|
type Value: ModelValue<'ctx>;
|
||||||
|
|
||||||
fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx>;
|
fn get_llvm_type(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx>;
|
||||||
fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value;
|
fn review(&self, value: AnyValueEnum<'ctx>) -> Self::Value;
|
||||||
|
|
||||||
fn alloca(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> Pointer<'ctx, Self> {
|
fn alloca(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> Pointer<'ctx, Self> {
|
||||||
Pointer {
|
Pointer {
|
||||||
|
|
|
@ -7,7 +7,7 @@ use itertools::Itertools;
|
||||||
|
|
||||||
use crate::codegen::CodeGenContext;
|
use crate::codegen::CodeGenContext;
|
||||||
|
|
||||||
use super::{Model, ModelValue, Pointer};
|
use super::{core::CanCheckLLVMType, Model, ModelValue, Pointer};
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
pub struct Field<E> {
|
pub struct Field<E> {
|
||||||
|
@ -17,14 +17,12 @@ pub struct Field<E> {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Like [`Field<E>`] but element must be [`BasicTypeEnum<'ctx>`]
|
// Like [`Field<E>`] but element must be [`BasicTypeEnum<'ctx>`]
|
||||||
#[derive(Debug, Clone, Copy)]
|
|
||||||
struct FieldLLVM<'ctx> {
|
struct FieldLLVM<'ctx> {
|
||||||
gep_index: u64,
|
gep_index: u64,
|
||||||
name: &'ctx str,
|
name: &'ctx str,
|
||||||
llvm_type: BasicTypeEnum<'ctx>,
|
llvm_type: Box<dyn CanCheckLLVMType>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct FieldBuilder<'ctx> {
|
pub struct FieldBuilder<'ctx> {
|
||||||
pub ctx: &'ctx Context,
|
pub ctx: &'ctx Context,
|
||||||
gep_index_counter: u64,
|
gep_index_counter: u64,
|
||||||
|
@ -57,6 +55,33 @@ impl<'ctx> FieldBuilder<'ctx> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn check_basic_types_match<'ctx, A, B>(expected: A, got: B) -> Result<(), String>
|
||||||
|
where
|
||||||
|
A: BasicType<'ctx>,
|
||||||
|
B: BasicType<'ctx>,
|
||||||
|
{
|
||||||
|
let expected = expected.as_basic_type_enum();
|
||||||
|
let got = got.as_basic_type_enum();
|
||||||
|
|
||||||
|
// Put those logic into here,
|
||||||
|
// otherwise there is always a fallback reporting on any kind of mismatch
|
||||||
|
match (expected, got) {
|
||||||
|
(BasicTypeEnum::IntType(expected), BasicTypeEnum::IntType(got)) => {
|
||||||
|
if expected.get_bit_width() != got.get_bit_width() {
|
||||||
|
return Err(format!(
|
||||||
|
"Expected IntType ({expected}-bit(s)), got IntType ({got}-bit(s))"
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(expected, got) => {
|
||||||
|
if expected != got {
|
||||||
|
return Err(format!("Expected {expected}, got {got}"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
pub trait IsStruct<'ctx>: Clone + Copy {
|
pub trait IsStruct<'ctx>: Clone + Copy {
|
||||||
type Fields;
|
type Fields;
|
||||||
|
|
||||||
|
@ -75,7 +100,12 @@ pub trait IsStruct<'ctx>: Clone + Copy {
|
||||||
|
|
||||||
let field_types =
|
let field_types =
|
||||||
builder.fields.iter().map(|field_info| field_info.llvm_type).collect_vec();
|
builder.fields.iter().map(|field_info| field_info.llvm_type).collect_vec();
|
||||||
ctx.struct_type(&field_types, false)
|
ctx.struct_type(&field_types, false).as_basic_type_enum().into_pointer_type().get_el
|
||||||
|
}
|
||||||
|
|
||||||
|
fn check_struct_type(&self) {
|
||||||
|
// Datatypes behind
|
||||||
|
// check_basic_types_match
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -94,6 +124,12 @@ impl<'ctx, S: IsStruct<'ctx>> ModelValue<'ctx> for Struct<'ctx, S> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<'ctx, S: IsStruct<'ctx>> CanCheckLLVMType<'ctx> for StructModel<S> {
|
||||||
|
fn check_llvm_type<'ctx>(&self, ctx: &'ctx Context) -> Result<(), String> {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<'ctx, S: IsStruct<'ctx>> Model<'ctx> for StructModel<S> {
|
impl<'ctx, S: IsStruct<'ctx>> Model<'ctx> for StructModel<S> {
|
||||||
type Value = Struct<'ctx, S>; // TODO: enrich it
|
type Value = Struct<'ctx, S>; // TODO: enrich it
|
||||||
|
|
||||||
|
@ -101,7 +137,7 @@ impl<'ctx, S: IsStruct<'ctx>> Model<'ctx> for StructModel<S> {
|
||||||
self.0.get_struct_type(ctx).as_basic_type_enum()
|
self.0.get_struct_type(ctx).as_basic_type_enum()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value {
|
fn review(&self, value: AnyValueEnum<'ctx>) -> Self::Value {
|
||||||
// TODO: check structure
|
// TODO: check structure
|
||||||
Struct { structure: self.0, value: value.into_struct_value() }
|
Struct { structure: self.0, value: value.into_struct_value() }
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,7 +27,7 @@ impl<'ctx> Model<'ctx> for IntModel<'ctx> {
|
||||||
self.0.as_basic_type_enum()
|
self.0.as_basic_type_enum()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value {
|
fn review(&self, value: AnyValueEnum<'ctx>) -> Self::Value {
|
||||||
let int = value.into_int_value();
|
let int = value.into_int_value();
|
||||||
assert_eq!(int.get_type().get_bit_width(), self.0.get_bit_width());
|
assert_eq!(int.get_type().get_bit_width(), self.0.get_bit_width());
|
||||||
Int(int)
|
Int(int)
|
||||||
|
@ -130,7 +130,7 @@ impl<'ctx, T: IsFixedInt> Model<'ctx> for FixedIntModel<T> {
|
||||||
T::get_int_type(ctx).as_basic_type_enum()
|
T::get_int_type(ctx).as_basic_type_enum()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value {
|
fn review(&self, value: AnyValueEnum<'ctx>) -> Self::Value {
|
||||||
let value = value.into_int_value();
|
let value = value.into_int_value();
|
||||||
assert_eq!(value.get_type().get_bit_width(), T::get_bit_width());
|
assert_eq!(value.get_type().get_bit_width(), T::get_bit_width());
|
||||||
FixedInt { int: self.0, value }
|
FixedInt { int: self.0, value }
|
||||||
|
|
|
@ -31,7 +31,7 @@ impl<'ctx, E: Model<'ctx>> Pointer<'ctx, E> {
|
||||||
|
|
||||||
pub fn load(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> E::Value {
|
pub fn load(&self, ctx: &CodeGenContext<'ctx, '_>, name: &str) -> E::Value {
|
||||||
let val = ctx.builder.build_load(self.value, name).unwrap();
|
let val = ctx.builder.build_load(self.value, name).unwrap();
|
||||||
self.element.check_llvm_value(val.as_any_value_enum())
|
self.element.review(val.as_any_value_enum())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn to_opaque(self) -> OpaquePointer<'ctx> {
|
pub fn to_opaque(self) -> OpaquePointer<'ctx> {
|
||||||
|
@ -66,7 +66,7 @@ impl<'ctx, E: Model<'ctx>> Model<'ctx> for PointerModel<E> {
|
||||||
self.0.get_llvm_type(ctx).ptr_type(AddressSpace::default()).as_basic_type_enum()
|
self.0.get_llvm_type(ctx).ptr_type(AddressSpace::default()).as_basic_type_enum()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value {
|
fn review(&self, value: AnyValueEnum<'ctx>) -> Self::Value {
|
||||||
// TODO: Check get_element_type()? for LLVM 14 at least...
|
// TODO: Check get_element_type()? for LLVM 14 at least...
|
||||||
Pointer { element: self.0, value: value.into_pointer_value() }
|
Pointer { element: self.0, value: value.into_pointer_value() }
|
||||||
}
|
}
|
||||||
|
@ -92,7 +92,7 @@ impl<'ctx> Model<'ctx> for OpaquePointerModel {
|
||||||
ctx.i8_type().ptr_type(AddressSpace::default()).as_basic_type_enum()
|
ctx.i8_type().ptr_type(AddressSpace::default()).as_basic_type_enum()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn check_llvm_value(&self, value: AnyValueEnum<'ctx>) -> Self::Value {
|
fn review(&self, value: AnyValueEnum<'ctx>) -> Self::Value {
|
||||||
let ptr = value.into_pointer_value();
|
let ptr = value.into_pointer_value();
|
||||||
// TODO: remove this check once LLVM pointers do not have `get_element_type()`
|
// TODO: remove this check once LLVM pointers do not have `get_element_type()`
|
||||||
assert_eq!(ptr.get_type().get_element_type().into_int_type().get_bit_width(), 8);
|
assert_eq!(ptr.get_type().get_element_type().into_int_type().get_bit_width(), 8);
|
||||||
|
|
|
@ -189,10 +189,6 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
|
||||||
v.data().ptr_offset(ctx, generator, &index, name)
|
v.data().ptr_offset(ctx, generator, &index, name)
|
||||||
}
|
}
|
||||||
|
|
||||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
|
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -207,10 +203,26 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
|
||||||
target: &Expr<Option<Type>>,
|
target: &Expr<Option<Type>>,
|
||||||
value: ValueEnum<'ctx>,
|
value: ValueEnum<'ctx>,
|
||||||
) -> Result<(), String> {
|
) -> Result<(), String> {
|
||||||
|
/*
|
||||||
|
To handle assignment statements `target = value`, with
|
||||||
|
special care taken for targets `gen_store_target` cannot handle, these are:
|
||||||
|
- Case 1. target is a Tuple
|
||||||
|
- e.g., `(x, y, z, w) = value`
|
||||||
|
- Case 2. *Sliced* list assignment `list.__setitem__`
|
||||||
|
- e.g., `my_list[1:3] = [100, 101]`, BUT NOT `my_list[0] = 99` (gen_store_target knows how to handle these),
|
||||||
|
- Case 3. Indexed ndarray assignment `ndarray.__setitem__`
|
||||||
|
- e.g., `my_ndarray[::-1, :] = 3`, `my_ndarray[:, 3::-1] = their_ndarray[10::2]`
|
||||||
|
- NOTE: Technically speaking, if `target` is sliced in such as way that it is referencing a
|
||||||
|
single element/scalar, we *could* implement gen_store_target for this special case;
|
||||||
|
but it is much, *much* simpler to generalize all indexed ndarray assignment without
|
||||||
|
special handling on that edgecase.
|
||||||
|
- Otherwise, use `gen_store_target`
|
||||||
|
*/
|
||||||
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
match &target.node {
|
if let ExprKind::Tuple { elts, .. } = &target.node {
|
||||||
ExprKind::Tuple { elts, .. } => {
|
// Handle Case 1. target is a Tuple
|
||||||
let BasicValueEnum::StructValue(v) =
|
let BasicValueEnum::StructValue(v) =
|
||||||
value.to_basic_value_enum(ctx, generator, target.custom.unwrap())?
|
value.to_basic_value_enum(ctx, generator, target.custom.unwrap())?
|
||||||
else {
|
else {
|
||||||
|
@ -224,20 +236,39 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
|
||||||
.unwrap();
|
.unwrap();
|
||||||
generator.gen_assign(ctx, elt, v.into())?;
|
generator.gen_assign(ctx, elt, v.into())?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return Ok(()); // Terminate
|
||||||
}
|
}
|
||||||
ExprKind::Subscript { value: ls, slice, .. }
|
|
||||||
if matches!(&slice.node, ExprKind::Slice { .. }) =>
|
// Else, try checking if it's Case 2 or 3, and they *ONLY*
|
||||||
|
// happen if `target.node` is a `ExprKind::Subscript`, so do a special check
|
||||||
|
if let ExprKind::Subscript { value: target_without_slice, slice, .. } = &target.node {
|
||||||
|
// Get the type of target
|
||||||
|
let target_ty = target.custom.unwrap();
|
||||||
|
let target_ty_enum = &*ctx.unifier.get_ty(target_ty);
|
||||||
|
|
||||||
|
// Pattern match on this pair.
|
||||||
|
// This is done like this because of Case 2 - slice.node has to be in a specific pattern
|
||||||
|
match (target_ty_enum, &slice.node) {
|
||||||
|
(TypeEnum::TObj { obj_id, .. }, ExprKind::Slice { lower, upper, step })
|
||||||
|
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
|
||||||
{
|
{
|
||||||
let ExprKind::Slice { lower, upper, step } = &slice.node else { unreachable!() };
|
// Case 2. *Sliced* list assignment
|
||||||
|
|
||||||
let ls = generator
|
let ls = generator
|
||||||
.gen_expr(ctx, ls)?
|
.gen_expr(ctx, target_without_slice)?
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.to_basic_value_enum(ctx, generator, ls.custom.unwrap())?
|
.to_basic_value_enum(ctx, generator, target_without_slice.custom.unwrap())?
|
||||||
.into_pointer_value();
|
.into_pointer_value();
|
||||||
let ls = ListValue::from_ptr_val(ls, llvm_usize, None);
|
let ls = ListValue::from_ptr_val(ls, llvm_usize, None);
|
||||||
let Some((start, end, step)) =
|
let Some((start, end, step)) = handle_slice_indices(
|
||||||
handle_slice_indices(lower, upper, step, ctx, generator, ls.load_size(ctx, None))?
|
lower,
|
||||||
|
upper,
|
||||||
|
step,
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
ls.load_size(ctx, None),
|
||||||
|
)?
|
||||||
else {
|
else {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
};
|
};
|
||||||
|
@ -268,8 +299,33 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
|
||||||
return Ok(());
|
return Ok(());
|
||||||
};
|
};
|
||||||
list_slice_assignment(generator, ctx, ty, ls, (start, end, step), value, src_ind);
|
list_slice_assignment(generator, ctx, ty, ls, (start, end, step), value, src_ind);
|
||||||
|
|
||||||
|
return Ok(()); // Terminate
|
||||||
|
}
|
||||||
|
(TypeEnum::TObj { obj_id, .. }, _)
|
||||||
|
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
|
||||||
|
{
|
||||||
|
// Case 3. Indexed ndarray assignment
|
||||||
|
|
||||||
|
let target = generator.gen_expr(ctx, target)?.unwrap().to_basic_value_enum(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
target.custom.unwrap(),
|
||||||
|
);
|
||||||
|
|
||||||
|
// let value = value.to_basic_value_enum(ctx, generator, value);
|
||||||
|
|
||||||
|
todo!();
|
||||||
|
|
||||||
|
return Ok(()); // Terminate
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
|
// Fallthrough
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// None of the cases match. We should actually use `gen_store_target`.
|
||||||
let name = if let ExprKind::Name { id, .. } = &target.node {
|
let name = if let ExprKind::Name { id, .. } = &target.node {
|
||||||
format!("{id}.addr")
|
format!("{id}.addr")
|
||||||
} else {
|
} else {
|
||||||
|
@ -288,8 +344,6 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
|
||||||
}
|
}
|
||||||
let val = value.to_basic_value_enum(ctx, generator, target.custom.unwrap())?;
|
let val = value.to_basic_value_enum(ctx, generator, target.custom.unwrap())?;
|
||||||
ctx.builder.build_store(ptr, val).unwrap();
|
ctx.builder.build_store(ptr, val).unwrap();
|
||||||
}
|
|
||||||
};
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1470,7 +1470,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
let ndarray_ptr_model =
|
let ndarray_ptr_model =
|
||||||
PointerModel(StructModel(NpArray { sizet }));
|
PointerModel(StructModel(NpArray { sizet }));
|
||||||
let ndarray_ptr =
|
let ndarray_ptr =
|
||||||
ndarray_ptr_model.check_llvm_value(arg.as_any_value_enum());
|
ndarray_ptr_model.review(arg.as_any_value_enum());
|
||||||
|
|
||||||
// Calculate len
|
// Calculate len
|
||||||
// NOTE: Unsized object is asserted in IRRT
|
// NOTE: Unsized object is asserted in IRRT
|
||||||
|
|
Loading…
Reference in New Issue