forked from M-Labs/nac3
WIP: core: more progress
This commit is contained in:
parent
635542a36d
commit
e75db2c26f
|
@ -173,6 +173,8 @@ namespace {
|
||||||
// NOTE: Formally this should be of type `void *`, but clang
|
// NOTE: Formally this should be of type `void *`, but clang
|
||||||
// translates `void *` to `i8 *` when run with `-S -emit-llvm`,
|
// translates `void *` to `i8 *` when run with `-S -emit-llvm`,
|
||||||
// so we will put `uint8_t *` here for clarity.
|
// so we will put `uint8_t *` here for clarity.
|
||||||
|
//
|
||||||
|
// This pointer should point to the first element of the ndarray directly
|
||||||
uint8_t *data;
|
uint8_t *data;
|
||||||
|
|
||||||
// The number of bytes of a single element in `data`.
|
// The number of bytes of a single element in `data`.
|
||||||
|
@ -308,6 +310,7 @@ namespace {
|
||||||
irrt_assert(dst_ndarray->ndims == ndarray_util::deduce_ndims_after_slicing(this->ndims, num_ndslices, ndslices));
|
irrt_assert(dst_ndarray->ndims == ndarray_util::deduce_ndims_after_slicing(this->ndims, num_ndslices, ndslices));
|
||||||
|
|
||||||
dst_ndarray->data = this->data;
|
dst_ndarray->data = this->data;
|
||||||
|
dst_ndarray->itemsize = this->itemsize;
|
||||||
|
|
||||||
SizeT this_axis = 0;
|
SizeT this_axis = 0;
|
||||||
SizeT dst_axis = 0;
|
SizeT dst_axis = 0;
|
||||||
|
@ -346,7 +349,18 @@ namespace {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
irrt_assert(dst_axis == dst_ndarray->ndims); // Sanity check on the implementation
|
/*
|
||||||
|
Reference python code:
|
||||||
|
```python
|
||||||
|
dst_ndarray.shape.extend(this.shape[this_axis:])
|
||||||
|
dst_ndarray.strides.extend(this.strides[this_axis:])
|
||||||
|
```
|
||||||
|
*/
|
||||||
|
|
||||||
|
for (; dst_axis < dst_ndarray->ndims; dst_axis++, this_axis++) {
|
||||||
|
dst_ndarray->shape[dst_axis] = this->shape[this_axis];
|
||||||
|
dst_ndarray->strides[dst_axis] = this->strides[this_axis];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Similar to `np.broadcast_to(<ndarray>, <target_shape>)`
|
// Similar to `np.broadcast_to(<ndarray>, <target_shape>)`
|
||||||
|
@ -435,6 +449,23 @@ namespace {
|
||||||
this->set_pelement_value(this_pelement, src_pelement);
|
this->set_pelement_value(this_pelement, src_pelement);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: DOCUMENT ME
|
||||||
|
bool is_unsized() {
|
||||||
|
return this->ndims == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simulate `len(<ndarray>)`
|
||||||
|
// See (it doesn't help): https://numpy.org/doc/stable/reference/generated/numpy.ndarray.__len__.html#numpy.ndarray.__len__
|
||||||
|
SliceIndex len() {
|
||||||
|
// If you do `len(np.asarray(42))` (note that its `.shape` is just `()` - an empty tuple),
|
||||||
|
// numpy throws a `TypeError: len() of unsized object`
|
||||||
|
irrt_assert(!this->is_unsized());
|
||||||
|
|
||||||
|
// Apparently `len(<ndarray>)` is defined to be the first dimension
|
||||||
|
// REFERENCE: https://stackoverflow.com/questions/43081809/len-of-a-numpy-array-in-python
|
||||||
|
return (SliceIndex) this->shape[0];
|
||||||
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -478,4 +509,12 @@ extern "C" {
|
||||||
void __nac3_ndarray_subscript64(NDArray<int64_t>* ndarray, int32_t num_slices, NDSlice* slices, NDArray<int64_t> *dst_ndarray) {
|
void __nac3_ndarray_subscript64(NDArray<int64_t>* ndarray, int32_t num_slices, NDSlice* slices, NDArray<int64_t> *dst_ndarray) {
|
||||||
ndarray->subscript(num_slices, slices, dst_ndarray);
|
ndarray->subscript(num_slices, slices, dst_ndarray);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SliceIndex __nac3_ndarray_len(NDArray<int32_t>* ndarray) {
|
||||||
|
return ndarray->len();
|
||||||
|
}
|
||||||
|
|
||||||
|
SliceIndex __nac3_ndarray_len64(NDArray<int64_t>* ndarray) {
|
||||||
|
return ndarray->len();
|
||||||
|
}
|
||||||
}
|
}
|
|
@ -475,6 +475,50 @@ void test_ndslice_2() {
|
||||||
assert_values_match("dst_ndarray[1]", "%f", 9.0, *((double *) dst_ndarray.get_pelement_by_indices((int32_t[dst_ndims]) { 1 })));
|
assert_values_match("dst_ndarray[1]", "%f", 9.0, *((double *) dst_ndarray.get_pelement_by_indices((int32_t[dst_ndims]) { 1 })));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void test_ndslice_3() {
|
||||||
|
BEGIN_TEST();
|
||||||
|
|
||||||
|
double in_data[12] = { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 };
|
||||||
|
const int32_t in_itemsize = sizeof(double);
|
||||||
|
const int32_t in_ndims = 2;
|
||||||
|
int32_t in_shape[in_ndims] = { 3, 4 };
|
||||||
|
int32_t in_strides[in_ndims] = {};
|
||||||
|
NDArray<int32_t> ndarray = {
|
||||||
|
.data = (uint8_t*) in_data,
|
||||||
|
.itemsize = in_itemsize,
|
||||||
|
.ndims = in_ndims,
|
||||||
|
.shape = in_shape,
|
||||||
|
.strides = in_strides
|
||||||
|
};
|
||||||
|
ndarray.set_strides_by_shape();
|
||||||
|
|
||||||
|
const int32_t dst_ndims = 2;
|
||||||
|
int32_t dst_shape[dst_ndims] = {999, 999}; // Empty values
|
||||||
|
int32_t dst_strides[dst_ndims] = {999, 999}; // Empty values
|
||||||
|
NDArray<int32_t> dst_ndarray = {
|
||||||
|
.data = nullptr,
|
||||||
|
.ndims = dst_ndims,
|
||||||
|
.shape = dst_shape,
|
||||||
|
.strides = dst_strides
|
||||||
|
};
|
||||||
|
|
||||||
|
// Create the slice in `ndarray[2:3]`
|
||||||
|
UserSlice user_slice_1 = {
|
||||||
|
.start_defined = 1,
|
||||||
|
.start = 2,
|
||||||
|
.stop_defined = 1,
|
||||||
|
.stop = 3,
|
||||||
|
.step_defined = 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
const int32_t num_ndslices = 1;
|
||||||
|
NDSlice ndslices[num_ndslices] = {
|
||||||
|
{ .type = INPUT_SLICE_TYPE_SLICE, .slice = (uint8_t*) &user_slice_1 },
|
||||||
|
};
|
||||||
|
|
||||||
|
ndarray.subscript(num_ndslices, ndslices, &dst_ndarray);
|
||||||
|
}
|
||||||
|
|
||||||
void test_can_broadcast_shape() {
|
void test_can_broadcast_shape() {
|
||||||
BEGIN_TEST();
|
BEGIN_TEST();
|
||||||
|
|
||||||
|
@ -644,6 +688,7 @@ int main() {
|
||||||
test_slice_4();
|
test_slice_4();
|
||||||
test_ndslice_1();
|
test_ndslice_1();
|
||||||
test_ndslice_2();
|
test_ndslice_2();
|
||||||
|
test_ndslice_3();
|
||||||
test_can_broadcast_shape();
|
test_can_broadcast_shape();
|
||||||
test_ndarray_broadcast_1();
|
test_ndarray_broadcast_1();
|
||||||
return 0;
|
return 0;
|
||||||
|
|
|
@ -1979,43 +1979,32 @@ impl<'ctx> NpArrayType<'ctx> {
|
||||||
/// - `ndarray.itemsize` will be initialized to the size of `self.elem_type.size_of()`.
|
/// - `ndarray.itemsize` will be initialized to the size of `self.elem_type.size_of()`.
|
||||||
/// - `ndarray.shape` and `ndarray.strides` will be allocated on the stack with number of elements being `in_ndims`,
|
/// - `ndarray.shape` and `ndarray.strides` will be allocated on the stack with number of elements being `in_ndims`,
|
||||||
/// all with empty/uninitialized values.
|
/// all with empty/uninitialized values.
|
||||||
pub fn var_alloc<G>(
|
pub fn alloca(
|
||||||
&self,
|
&self,
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
in_ndims: IntValue<'ctx>,
|
in_ndims: IntValue<'ctx>,
|
||||||
name: Option<&str>,
|
name: &str,
|
||||||
) -> NpArrayValue<'ctx>
|
) -> NpArrayValue<'ctx> {
|
||||||
where
|
let ptr = ctx
|
||||||
G: CodeGenerator + ?Sized,
|
.builder
|
||||||
{
|
.build_alloca(self.get_struct_type(ctx.ctx).as_basic_type_enum(), name)
|
||||||
let ptr = generator
|
|
||||||
.gen_var_alloc(ctx, self.get_struct_type(ctx.ctx).as_basic_type_enum(), name)
|
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
// Allocate `in_dims` number of `size_type` on the stack for `shape` and `strides`
|
// Allocate `in_dims` number of `size_type` on the stack for `shape` and `strides`
|
||||||
let allocated_shape = generator
|
let allocated_shape = ctx
|
||||||
.gen_array_var_alloc(
|
.builder
|
||||||
ctx,
|
.build_array_alloca(self.size_type.as_basic_type_enum(), in_ndims, "allocated_shape")
|
||||||
self.size_type.as_basic_type_enum(),
|
|
||||||
in_ndims,
|
|
||||||
Some("allocated_shape"),
|
|
||||||
)
|
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let allocated_strides = generator
|
let allocated_strides = ctx
|
||||||
.gen_array_var_alloc(
|
.builder
|
||||||
ctx,
|
.build_array_alloca(self.size_type.as_basic_type_enum(), in_ndims, "allocated_strides")
|
||||||
self.size_type.as_basic_type_enum(),
|
|
||||||
in_ndims,
|
|
||||||
Some("allocated_strides"),
|
|
||||||
)
|
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let value = NpArrayValue { ty: *self, ptr };
|
let value = NpArrayValue { ty: *self, ptr };
|
||||||
value.store_ndims(ctx, in_ndims);
|
value.store_ndims(ctx, in_ndims);
|
||||||
value.store_itemsize(ctx, self.elem_type.size_of().unwrap());
|
value.store_itemsize(ctx, self.elem_type.size_of().unwrap());
|
||||||
value.store_shape(ctx, allocated_shape.base_ptr(ctx, generator));
|
value.store_shape(ctx, allocated_shape);
|
||||||
value.store_strides(ctx, allocated_strides.base_ptr(ctx, generator));
|
value.store_strides(ctx, allocated_strides);
|
||||||
|
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
|
@ -2045,6 +2034,11 @@ pub struct NpArrayValue<'ctx> {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ctx> NpArrayValue<'ctx> {
|
impl<'ctx> NpArrayValue<'ctx> {
|
||||||
|
pub fn load_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||||
|
let field = self.ty.fields(ctx.ctx).data;
|
||||||
|
field.load(ctx, self.ptr).into_pointer_value()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, new_data_ptr: PointerValue<'ctx>) {
|
pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, new_data_ptr: PointerValue<'ctx>) {
|
||||||
let field = self.ty.fields(ctx.ctx).data;
|
let field = self.ty.fields(ctx.ctx).data;
|
||||||
field.store(ctx, self.ptr, new_data_ptr);
|
field.store(ctx, self.ptr, new_data_ptr);
|
||||||
|
|
|
@ -17,7 +17,8 @@ use crate::{
|
||||||
call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax,
|
call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax,
|
||||||
call_memcpy_generic,
|
call_memcpy_generic,
|
||||||
},
|
},
|
||||||
need_sret, numpy::{self, call_ndarray_subscript_impl},
|
need_sret,
|
||||||
|
numpy::{self, call_ndarray_subscript_impl, get_ndarray_first_element},
|
||||||
stmt::{
|
stmt::{
|
||||||
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
|
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
|
||||||
gen_var,
|
gen_var,
|
||||||
|
@ -38,7 +39,7 @@ use crate::{
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
attributes::{Attribute, AttributeLoc},
|
attributes::{Attribute, AttributeLoc},
|
||||||
types::{AnyType, BasicType, BasicTypeEnum},
|
types::{AnyType, BasicType, BasicTypeEnum},
|
||||||
values::{BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue},
|
values::{BasicValue, BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue},
|
||||||
AddressSpace, IntPredicate, OptimizationLevel,
|
AddressSpace, IntPredicate, OptimizationLevel,
|
||||||
};
|
};
|
||||||
use itertools::{chain, izip, Either, Itertools};
|
use itertools::{chain, izip, Either, Itertools};
|
||||||
|
@ -2100,12 +2101,14 @@ pub fn gen_cmpop_expr<'ctx, G: CodeGenerator>(
|
||||||
/// Generates code for a subscript expression on an `ndarray`.
|
/// Generates code for a subscript expression on an `ndarray`.
|
||||||
///
|
///
|
||||||
/// * `ty` - The `Type` of the `NDArray` elements.
|
/// * `ty` - The `Type` of the `NDArray` elements.
|
||||||
|
/// * `ndims` - The `Type` of the `NDArray` number-of-dimensions `Literal`.
|
||||||
/// * `ndarray` - The `NDArray` value.
|
/// * `ndarray` - The `NDArray` value.
|
||||||
/// * `slice` - The slice expression used to subscript into the `ndarray`.
|
/// * `slice` - The slice expression used to subscript into the `ndarray`.
|
||||||
fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
ty: Type,
|
ty: Type,
|
||||||
|
ndims: Type,
|
||||||
ndarray: NpArrayValue<'ctx>,
|
ndarray: NpArrayValue<'ctx>,
|
||||||
slice: &Expr<Option<Type>>,
|
slice: &Expr<Option<Type>>,
|
||||||
) -> Result<Option<ValueEnum<'ctx>>, String> {
|
) -> Result<Option<ValueEnum<'ctx>>, String> {
|
||||||
|
@ -2165,8 +2168,8 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||||
let stop = help(stop)?;
|
let stop = help(stop)?;
|
||||||
let step = help(step)?;
|
let step = help(step)?;
|
||||||
|
|
||||||
// NOTE: Now start stop step should all be 32-bit ints after typechecking
|
// start stop step should all be 32-bit ints after typechecking,
|
||||||
// ...and `IrrtUserSlice` expects `int32`s
|
// and `IrrtUserSlice` expects `int32`s
|
||||||
NDSlice::Slice(UserSlice { start, stop, step })
|
NDSlice::Slice(UserSlice { start, stop, step })
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
|
@ -2185,12 +2188,36 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||||
ndslices.push(ndslice);
|
ndslices.push(ndslice);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Finally, perform the actual subscript logic
|
// TODO: what is going on? why the original implementation doesn't assert `ndims_values.len() == 1`
|
||||||
let subndarray = call_ndarray_subscript_impl(generator, ctx, ndarray, &ndslices.iter().collect_vec())?;
|
// Extract the `ndims` from a `Type` to `i128`
|
||||||
|
let TypeEnum::TLiteral { values: ndims_values, .. } = &*ctx.unifier.get_ty_immutable(ndims)
|
||||||
|
else {
|
||||||
|
unreachable!()
|
||||||
|
};
|
||||||
|
assert_eq!(ndims_values.len(), 1);
|
||||||
|
let ndims = i128::try_from(ndims_values[0].clone()).unwrap() as u64;
|
||||||
|
assert!(ndims > 0);
|
||||||
|
|
||||||
// ...and return the result
|
// Deduce the subndarray's ndims
|
||||||
let result = ValueEnum::Dynamic(subndarray.ptr.into());
|
let dst_ndims = deduce_ndims_after_slicing(ndims, ndslices.iter());
|
||||||
Ok(Some(result))
|
|
||||||
|
// Finally, perform the actual subscript logic
|
||||||
|
let subndarray = call_ndarray_subscript_impl(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
ndarray,
|
||||||
|
&ndslices.iter().collect_vec(),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// ...and return the result, with two cases
|
||||||
|
let result = if dst_ndims == 0 {
|
||||||
|
// 1) ndims == 0 (this happens when you do `np.zerps((3, 4))[1, 1]`), return *THE ELEMENT*
|
||||||
|
get_ndarray_first_element(ctx, subndarray, "element")
|
||||||
|
} else {
|
||||||
|
// 2) ndims > 0 (other cases), return subndarray
|
||||||
|
subndarray.ptr.as_basic_value_enum()
|
||||||
|
};
|
||||||
|
Ok(Some(ValueEnum::Dynamic(result)))
|
||||||
|
|
||||||
// let llvm_i1 = ctx.ctx.bool_type();
|
// let llvm_i1 = ctx.ctx.bool_type();
|
||||||
// let llvm_i32 = ctx.ctx.i32_type();
|
// let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
@ -3116,7 +3143,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::NDArray.id() => {
|
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||||
let (elem_ty, _) = params.iter().map(|(_, ty)| ty).collect_tuple().unwrap();
|
let (elem_ty, ndims) = params.iter().map(|(_, ty)| ty).collect_tuple().unwrap();
|
||||||
|
|
||||||
// Get the pointer to the ndarray described by `value`
|
// Get the pointer to the ndarray described by `value`
|
||||||
let ndarray_ptr = if let Some(v) = generator.gen_expr(ctx, value)? {
|
let ndarray_ptr = if let Some(v) = generator.gen_expr(ctx, value)? {
|
||||||
|
@ -3136,7 +3163,9 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||||
let ndarray = ndarray_ty.value_from_ptr(ctx.ctx, ndarray_ptr);
|
let ndarray = ndarray_ty.value_from_ptr(ctx.ctx, ndarray_ptr);
|
||||||
|
|
||||||
// Implementation
|
// Implementation
|
||||||
return gen_ndarray_subscript_expr(generator, ctx, *elem_ty, ndarray, slice);
|
return gen_ndarray_subscript_expr(
|
||||||
|
generator, ctx, *elem_ty, *ndims, ndarray, slice,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
TypeEnum::TTuple { .. } => {
|
TypeEnum::TTuple { .. } => {
|
||||||
let index: u32 =
|
let index: u32 =
|
||||||
|
|
|
@ -18,6 +18,7 @@ use super::{
|
||||||
};
|
};
|
||||||
use crate::codegen::classes::TypedArrayLikeAccessor;
|
use crate::codegen::classes::TypedArrayLikeAccessor;
|
||||||
use crate::codegen::stmt::gen_for_callback_incrementing;
|
use crate::codegen::stmt::gen_for_callback_incrementing;
|
||||||
|
use crossbeam::channel::IntoIter;
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
attributes::{Attribute, AttributeLoc},
|
attributes::{Attribute, AttributeLoc},
|
||||||
context::Context,
|
context::Context,
|
||||||
|
@ -25,7 +26,7 @@ use inkwell::{
|
||||||
module::Module,
|
module::Module,
|
||||||
types::{BasicType, BasicTypeEnum, FunctionType, IntType, PointerType, StructType},
|
types::{BasicType, BasicTypeEnum, FunctionType, IntType, PointerType, StructType},
|
||||||
values::{
|
values::{
|
||||||
BasicValue, BasicValueEnum, CallSiteValue, FloatValue, FunctionValue, IntValue,
|
AnyValue, BasicValue, BasicValueEnum, CallSiteValue, FloatValue, FunctionValue, IntValue,
|
||||||
PointerValue,
|
PointerValue,
|
||||||
},
|
},
|
||||||
AddressSpace, IntPredicate,
|
AddressSpace, IntPredicate,
|
||||||
|
@ -1049,6 +1050,25 @@ pub enum NDSlice<'ctx> {
|
||||||
// TODO: Support https://numpy.org/doc/stable/user/basics.indexing.html#dimensional-indexing-tools; *should* be very easy to implement
|
// TODO: Support https://numpy.org/doc/stable/user/basics.indexing.html#dimensional-indexing-tools; *should* be very easy to implement
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn deduce_ndims_after_slicing<'ctx, I>(ndims: u64, ndslices: I) -> u64
|
||||||
|
where
|
||||||
|
I: Iterator<Item = &'ctx NDSlice<'ctx>>,
|
||||||
|
{
|
||||||
|
let mut final_ndims = ndims;
|
||||||
|
for ndslice in ndslices {
|
||||||
|
match ndslice {
|
||||||
|
NDSlice::Index(_) => {
|
||||||
|
// Index demote the output rank by 1
|
||||||
|
final_ndims -= 1;
|
||||||
|
}
|
||||||
|
NDSlice::Slice(_) => {
|
||||||
|
// Rank isn't changed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
final_ndims
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: Empty struct
|
// TODO: Empty struct
|
||||||
pub struct IrrtNDSlice {}
|
pub struct IrrtNDSlice {}
|
||||||
|
|
||||||
|
@ -1280,7 +1300,7 @@ pub fn call_nac3_ndarray_set_strides_by_shape<'ctx>(
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn call_nac3_ndarray_deduce_ndims_after_slicing_raw<'ctx>(
|
pub fn call_nac3_ndarray_deduce_ndims_after_slicing<'ctx>(
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
size_type: IntType<'ctx>,
|
size_type: IntType<'ctx>,
|
||||||
ndims: IntValue<'ctx>,
|
ndims: IntValue<'ctx>,
|
||||||
|
@ -1365,3 +1385,39 @@ pub fn call_nac3_ndarray_subscript<'ctx>(
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn call_nac3_len<'ctx>(
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
ndarray: NpArrayValue<'ctx>,
|
||||||
|
) -> IntValue<'ctx> {
|
||||||
|
let size_type = ndarray.ty.size_type;
|
||||||
|
|
||||||
|
// Get the IRRT function
|
||||||
|
let function = get_size_type_dependent_function(
|
||||||
|
ctx,
|
||||||
|
size_type,
|
||||||
|
"__nac3_ndarray_deduce_ndims_after_slicing",
|
||||||
|
|| {
|
||||||
|
get_sliceindex_type(ctx.ctx).fn_type(
|
||||||
|
&[
|
||||||
|
get_irrt_ndarray_ptr_type(ctx.ctx, size_type).into(), // NDArray<SizeT> *ndarray
|
||||||
|
],
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
// Call the IRRT function
|
||||||
|
ctx.builder
|
||||||
|
.build_call(
|
||||||
|
function,
|
||||||
|
&[
|
||||||
|
ndarray.ptr.into(), // ndarray
|
||||||
|
],
|
||||||
|
"len_of_ndarray",
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
.try_as_basic_value()
|
||||||
|
.unwrap_left()
|
||||||
|
.into_int_value()
|
||||||
|
}
|
||||||
|
|
|
@ -36,7 +36,7 @@ use nac3parser::ast::{Operator, StrRef};
|
||||||
use super::{
|
use super::{
|
||||||
classes::NpArrayValue,
|
classes::NpArrayValue,
|
||||||
irrt::{
|
irrt::{
|
||||||
call_nac3_ndarray_deduce_ndims_after_slicing_raw, call_nac3_ndarray_set_strides_by_shape,
|
call_nac3_ndarray_deduce_ndims_after_slicing, call_nac3_ndarray_set_strides_by_shape,
|
||||||
call_nac3_ndarray_size, call_nac3_ndarray_subscript, get_irrt_ndarray_ptr_type,
|
call_nac3_ndarray_size, call_nac3_ndarray_subscript, get_irrt_ndarray_ptr_type,
|
||||||
get_opaque_uint8_ptr_type, IrrtNDSlice, NDSlice,
|
get_opaque_uint8_ptr_type, IrrtNDSlice, NDSlice,
|
||||||
},
|
},
|
||||||
|
@ -2087,19 +2087,19 @@ fn copy_array_slice<'ctx, G, Src, Dst>(
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
fn var_alloc_ndarray<'ctx, G>(
|
fn alloca_ndarray<'ctx, G>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
elem_type: BasicTypeEnum<'ctx>,
|
elem_type: BasicTypeEnum<'ctx>,
|
||||||
ndims: IntValue<'ctx>,
|
ndims: IntValue<'ctx>,
|
||||||
name: Option<&str>,
|
name: &str,
|
||||||
) -> Result<NpArrayValue<'ctx>, String>
|
) -> Result<NpArrayValue<'ctx>, String>
|
||||||
where
|
where
|
||||||
G: CodeGenerator + ?Sized,
|
G: CodeGenerator + ?Sized,
|
||||||
{
|
{
|
||||||
let size_type = generator.get_size_type(ctx.ctx);
|
let size_type = generator.get_size_type(ctx.ctx);
|
||||||
let ndarray_ty = NpArrayType { size_type, elem_type };
|
let ndarray_ty = NpArrayType { size_type, elem_type };
|
||||||
let ndarray = ndarray_ty.var_alloc(generator, ctx, ndims, name);
|
let ndarray = ndarray_ty.alloca(ctx, ndims, name);
|
||||||
Ok(ndarray)
|
Ok(ndarray)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2270,7 +2270,7 @@ fn alloca_ndarray_and_init<'ctx, G>(
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
elem_type: BasicTypeEnum<'ctx>,
|
elem_type: BasicTypeEnum<'ctx>,
|
||||||
init_mode: NDArrayInitMode<'ctx, G>,
|
init_mode: NDArrayInitMode<'ctx, G>,
|
||||||
name: Option<&str>,
|
name: &str,
|
||||||
) -> Result<NpArrayValue<'ctx>, String>
|
) -> Result<NpArrayValue<'ctx>, String>
|
||||||
where
|
where
|
||||||
G: CodeGenerator + ?Sized,
|
G: CodeGenerator + ?Sized,
|
||||||
|
@ -2278,12 +2278,12 @@ where
|
||||||
// It is implemented verbosely in order to make the initialization modes super clear in their intent.
|
// It is implemented verbosely in order to make the initialization modes super clear in their intent.
|
||||||
match init_mode {
|
match init_mode {
|
||||||
NDArrayInitMode::SetNDim { ndim: ndims } => {
|
NDArrayInitMode::SetNDim { ndim: ndims } => {
|
||||||
let ndarray = var_alloc_ndarray(generator, ctx, elem_type, ndims, name)?;
|
let ndarray = alloca_ndarray(generator, ctx, elem_type, ndims, name)?;
|
||||||
Ok(ndarray)
|
Ok(ndarray)
|
||||||
}
|
}
|
||||||
NDArrayInitMode::SetShape { shape } => {
|
NDArrayInitMode::SetShape { shape } => {
|
||||||
let ndims = shape.count;
|
let ndims = shape.count;
|
||||||
let ndarray = var_alloc_ndarray(generator, ctx, elem_type, ndims, name)?;
|
let ndarray = alloca_ndarray(generator, ctx, elem_type, ndims, name)?;
|
||||||
|
|
||||||
// Fill `ndarray.shape` with `shape_producer`
|
// Fill `ndarray.shape` with `shape_producer`
|
||||||
(shape.write_to_slice)(generator, ctx, &ndarray.shape_slice(ctx));
|
(shape.write_to_slice)(generator, ctx, &ndarray.shape_slice(ctx));
|
||||||
|
@ -2292,7 +2292,7 @@ where
|
||||||
}
|
}
|
||||||
NDArrayInitMode::SetShapeAndAllocaData { shape } => {
|
NDArrayInitMode::SetShapeAndAllocaData { shape } => {
|
||||||
let ndims = shape.count;
|
let ndims = shape.count;
|
||||||
let ndarray = var_alloc_ndarray(generator, ctx, elem_type, ndims, name)?;
|
let ndarray = alloca_ndarray(generator, ctx, elem_type, ndims, name)?;
|
||||||
|
|
||||||
// Fill `ndarray.shape` with `shape_producer`
|
// Fill `ndarray.shape` with `shape_producer`
|
||||||
(shape.write_to_slice)(generator, ctx, &ndarray.shape_slice(ctx));
|
(shape.write_to_slice)(generator, ctx, &ndarray.shape_slice(ctx));
|
||||||
|
@ -2320,6 +2320,28 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn get_ndarray_first_element<'ctx>(
|
||||||
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
|
ndarray: NpArrayValue<'ctx>,
|
||||||
|
name: &str,
|
||||||
|
) -> BasicValueEnum<'ctx> {
|
||||||
|
let data = ndarray.load_data(ctx);
|
||||||
|
|
||||||
|
// Cast `data` to the actual element the `subndarray` holds
|
||||||
|
// otherwise `subndarray.data` is just a bunch of `uint8_t*`
|
||||||
|
let data = ctx
|
||||||
|
.builder
|
||||||
|
.build_pointer_cast(
|
||||||
|
data,
|
||||||
|
ndarray.ty.elem_type.ptr_type(AddressSpace::default()),
|
||||||
|
"data_casted",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Load the element
|
||||||
|
ctx.builder.build_load(data, name).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn call_ndarray_subscript_impl<'ctx, G>(
|
pub fn call_ndarray_subscript_impl<'ctx, G>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
@ -2344,10 +2366,10 @@ where
|
||||||
// Prepare the argument `slices`
|
// Prepare the argument `slices`
|
||||||
let ndslices_ptr = IrrtNDSlice::alloca_ndslices(ctx, ndslices);
|
let ndslices_ptr = IrrtNDSlice::alloca_ndslices(ctx, ndslices);
|
||||||
|
|
||||||
// Deduce the ndims
|
// Get `dst_ndims`
|
||||||
let dst_ndims = call_nac3_ndarray_deduce_ndims_after_slicing_raw(
|
let dst_ndims = call_nac3_ndarray_deduce_ndims_after_slicing(
|
||||||
ctx,
|
ctx,
|
||||||
ndarray.ty.size_type,
|
size_type,
|
||||||
ndims,
|
ndims,
|
||||||
num_slices,
|
num_slices,
|
||||||
ndslices_ptr,
|
ndslices_ptr,
|
||||||
|
@ -2359,7 +2381,7 @@ where
|
||||||
ctx,
|
ctx,
|
||||||
ndarray.ty.elem_type,
|
ndarray.ty.elem_type,
|
||||||
NDArrayInitMode::SetNDim { ndim: dst_ndims },
|
NDArrayInitMode::SetNDim { ndim: dst_ndims },
|
||||||
Some("subndarray"),
|
"subndarray",
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
call_nac3_ndarray_subscript(ctx, ndarray, num_slices, ndslices_ptr, dst_ndarray);
|
call_nac3_ndarray_subscript(ctx, ndarray, num_slices, ndslices_ptr, dst_ndarray);
|
||||||
|
@ -2374,7 +2396,7 @@ fn call_ndarray_empty_impl<'ctx, G>(
|
||||||
elem_ty: Type,
|
elem_ty: Type,
|
||||||
shape: BasicValueEnum<'ctx>,
|
shape: BasicValueEnum<'ctx>,
|
||||||
shape_ty: Type,
|
shape_ty: Type,
|
||||||
name: Option<&str>,
|
name: &str,
|
||||||
) -> Result<NpArrayValue<'ctx>, String>
|
) -> Result<NpArrayValue<'ctx>, String>
|
||||||
where
|
where
|
||||||
G: CodeGenerator + ?Sized,
|
G: CodeGenerator + ?Sized,
|
||||||
|
@ -2398,7 +2420,7 @@ fn call_ndarray_fill_impl<'ctx, G>(
|
||||||
shape: BasicValueEnum<'ctx>,
|
shape: BasicValueEnum<'ctx>,
|
||||||
shape_ty: Type,
|
shape_ty: Type,
|
||||||
fill_value: BasicValueEnum<'ctx>,
|
fill_value: BasicValueEnum<'ctx>,
|
||||||
name: Option<&str>,
|
name: &str,
|
||||||
) -> Result<NpArrayValue<'ctx>, String>
|
) -> Result<NpArrayValue<'ctx>, String>
|
||||||
where
|
where
|
||||||
G: CodeGenerator + ?Sized,
|
G: CodeGenerator + ?Sized,
|
||||||
|
@ -2430,7 +2452,7 @@ pub fn gen_ndarray_empty<'ctx>(
|
||||||
context.primitives.float,
|
context.primitives.float,
|
||||||
shape,
|
shape,
|
||||||
shape_ty,
|
shape_ty,
|
||||||
None,
|
"empty_ndarray",
|
||||||
)?;
|
)?;
|
||||||
Ok(ndarray.ptr)
|
Ok(ndarray.ptr)
|
||||||
}
|
}
|
||||||
|
@ -2462,7 +2484,7 @@ pub fn gen_ndarray_zeros<'ctx>(
|
||||||
shape,
|
shape,
|
||||||
shape_ty,
|
shape_ty,
|
||||||
float64_llvm_type.const_zero().as_basic_value_enum(),
|
float64_llvm_type.const_zero().as_basic_value_enum(),
|
||||||
Some("np_zeros.result"),
|
"zeros_ndarray",
|
||||||
)?;
|
)?;
|
||||||
Ok(ndarray.ptr)
|
Ok(ndarray.ptr)
|
||||||
}
|
}
|
||||||
|
@ -2494,7 +2516,7 @@ pub fn gen_ndarray_ones<'ctx>(
|
||||||
shape,
|
shape,
|
||||||
shape_ty,
|
shape_ty,
|
||||||
float64_llvm_type.const_float(1.0).as_basic_value_enum(),
|
float64_llvm_type.const_float(1.0).as_basic_value_enum(),
|
||||||
Some("np_ones.result"),
|
"ones_ndarray",
|
||||||
)?;
|
)?;
|
||||||
Ok(ndarray.ptr)
|
Ok(ndarray.ptr)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
use std::iter::once;
|
use std::iter::once;
|
||||||
|
|
||||||
use crate::util::SizeVariant;
|
use crate::{codegen::classes::NpArrayType, util::SizeVariant};
|
||||||
use helper::{debug_assert_prim_is_allowed, make_exception_fields, PrimDefDetails};
|
use helper::{debug_assert_prim_is_allowed, make_exception_fields, PrimDefDetails};
|
||||||
use indexmap::IndexMap;
|
use indexmap::IndexMap;
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
|
@ -1196,7 +1196,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
let func = match prim {
|
let func = match prim {
|
||||||
PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => gen_ndarray_empty,
|
PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => gen_ndarray_empty,
|
||||||
PrimDef::FunNpZeros => gen_ndarray_zeros,
|
PrimDef::FunNpZeros => gen_ndarray_zeros,
|
||||||
PrimDef::FunNpOnes => todo!(), // gen_ndarray_ones,
|
PrimDef::FunNpOnes => gen_ndarray_ones, // gen_ndarray_ones,
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
};
|
};
|
||||||
func(ctx, &obj, fun, &args, generator).map(|val| Some(val.as_basic_value_enum()))
|
func(ctx, &obj, fun, &args, generator).map(|val| Some(val.as_basic_value_enum()))
|
||||||
|
@ -1460,51 +1460,62 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
// TODO: Check is unsized and throw error if so
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
|
|
||||||
let arg = NDArrayValue::from_ptr_val(
|
// Parse `arg`
|
||||||
arg.into_pointer_value(),
|
let ndarray_ptr = arg.into_pointer_value(); // It has to be an ndarray
|
||||||
llvm_usize,
|
|
||||||
None,
|
|
||||||
);
|
|
||||||
|
|
||||||
let ndims = arg.dim_sizes().size(ctx, generator);
|
let size_type = generator.get_size_type(ctx.ctx);
|
||||||
ctx.make_assert(
|
let ndarray_ty = NpArrayType::new_opaque_elem(ctx, size_type); // We don't need to care about the element type - we only want the shape
|
||||||
generator,
|
let ndarray = ndarray_ty.value_from_ptr(ctx.ctx, ndarray_ptr);
|
||||||
ctx.builder
|
|
||||||
.build_int_compare(
|
|
||||||
IntPredicate::NE,
|
|
||||||
ndims,
|
|
||||||
llvm_usize.const_zero(),
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
.unwrap(),
|
|
||||||
"0:TypeError",
|
|
||||||
&format!("{name}() of unsized object", name = prim.name()),
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
|
|
||||||
let len = unsafe {
|
Some(call_nac3_len(ctx, ndarray).as_basic_value_enum())
|
||||||
arg.dim_sizes().get_typed_unchecked(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
&llvm_usize.const_zero(),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
};
|
|
||||||
|
|
||||||
if len.get_type().get_bit_width() == 32 {
|
// let llvm_i32 = ctx.ctx.i32_type();
|
||||||
Some(len.into())
|
// let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
} else {
|
|
||||||
Some(
|
// let arg = NDArrayValue::from_ptr_val(
|
||||||
ctx.builder
|
// arg.into_pointer_value(),
|
||||||
.build_int_truncate(len, llvm_i32, "len")
|
// llvm_usize,
|
||||||
.map(Into::into)
|
// None,
|
||||||
.unwrap(),
|
// );
|
||||||
)
|
|
||||||
}
|
// let ndims = arg.dim_sizes().size(ctx, generator);
|
||||||
|
// ctx.make_assert(
|
||||||
|
// generator,
|
||||||
|
// ctx.builder
|
||||||
|
// .build_int_compare(
|
||||||
|
// IntPredicate::NE,
|
||||||
|
// ndims,
|
||||||
|
// llvm_usize.const_zero(),
|
||||||
|
// "",
|
||||||
|
// )
|
||||||
|
// .unwrap(),
|
||||||
|
// "0:TypeError",
|
||||||
|
// &format!("{name}() of unsized object", name = prim.name()),
|
||||||
|
// [None, None, None],
|
||||||
|
// ctx.current_loc,
|
||||||
|
// );
|
||||||
|
|
||||||
|
// let len = unsafe {
|
||||||
|
// arg.dim_sizes().get_typed_unchecked(
|
||||||
|
// ctx,
|
||||||
|
// generator,
|
||||||
|
// &llvm_usize.const_zero(),
|
||||||
|
// None,
|
||||||
|
// )
|
||||||
|
// };
|
||||||
|
|
||||||
|
// if len.get_type().get_bit_width() == 32 {
|
||||||
|
// Some(len.into())
|
||||||
|
// } else {
|
||||||
|
// Some(
|
||||||
|
// ctx.builder
|
||||||
|
// .build_int_truncate(len, llvm_i32, "len")
|
||||||
|
// .map(Into::into)
|
||||||
|
// .unwrap(),
|
||||||
|
// )
|
||||||
|
// }
|
||||||
}
|
}
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,3 +1,18 @@
|
||||||
|
@extern
|
||||||
|
def output_float64(x: float):
|
||||||
|
...
|
||||||
|
|
||||||
|
def output_ndarray_float_1(n: ndarray[float, Literal[1]]):
|
||||||
|
for i in range(len(n)):
|
||||||
|
output_float64(n[i])
|
||||||
|
|
||||||
|
def output_ndarray_float_2(n: ndarray[float, Literal[2]]):
|
||||||
|
for r in range(len(n)):
|
||||||
|
for c in range(len(n[r])):
|
||||||
|
output_float64(n[r][c])
|
||||||
|
|
||||||
def run() -> int32:
|
def run() -> int32:
|
||||||
hello = np_zeros((3, 4))
|
hello = np_ones((3, 4))
|
||||||
|
# output_float64(hello[2, 3])
|
||||||
|
output_ndarray_float_1(hello[::-2, 2])
|
||||||
return 0
|
return 0
|
Loading…
Reference in New Issue