forked from M-Labs/nac3
1
0
Fork 0

WIP: core: more progress

This commit is contained in:
lyken 2024-07-12 00:41:53 +08:00
parent 635542a36d
commit e75db2c26f
8 changed files with 311 additions and 100 deletions

View File

@ -173,6 +173,8 @@ namespace {
// NOTE: Formally this should be of type `void *`, but clang // NOTE: Formally this should be of type `void *`, but clang
// translates `void *` to `i8 *` when run with `-S -emit-llvm`, // translates `void *` to `i8 *` when run with `-S -emit-llvm`,
// so we will put `uint8_t *` here for clarity. // so we will put `uint8_t *` here for clarity.
//
// This pointer should point to the first element of the ndarray directly
uint8_t *data; uint8_t *data;
// The number of bytes of a single element in `data`. // The number of bytes of a single element in `data`.
@ -308,6 +310,7 @@ namespace {
irrt_assert(dst_ndarray->ndims == ndarray_util::deduce_ndims_after_slicing(this->ndims, num_ndslices, ndslices)); irrt_assert(dst_ndarray->ndims == ndarray_util::deduce_ndims_after_slicing(this->ndims, num_ndslices, ndslices));
dst_ndarray->data = this->data; dst_ndarray->data = this->data;
dst_ndarray->itemsize = this->itemsize;
SizeT this_axis = 0; SizeT this_axis = 0;
SizeT dst_axis = 0; SizeT dst_axis = 0;
@ -346,7 +349,18 @@ namespace {
} }
} }
irrt_assert(dst_axis == dst_ndarray->ndims); // Sanity check on the implementation /*
Reference python code:
```python
dst_ndarray.shape.extend(this.shape[this_axis:])
dst_ndarray.strides.extend(this.strides[this_axis:])
```
*/
for (; dst_axis < dst_ndarray->ndims; dst_axis++, this_axis++) {
dst_ndarray->shape[dst_axis] = this->shape[this_axis];
dst_ndarray->strides[dst_axis] = this->strides[this_axis];
}
} }
// Similar to `np.broadcast_to(<ndarray>, <target_shape>)` // Similar to `np.broadcast_to(<ndarray>, <target_shape>)`
@ -435,6 +449,23 @@ namespace {
this->set_pelement_value(this_pelement, src_pelement); this->set_pelement_value(this_pelement, src_pelement);
} }
} }
// TODO: DOCUMENT ME
bool is_unsized() {
return this->ndims == 0;
}
// Simulate `len(<ndarray>)`
// See (it doesn't help): https://numpy.org/doc/stable/reference/generated/numpy.ndarray.__len__.html#numpy.ndarray.__len__
SliceIndex len() {
// If you do `len(np.asarray(42))` (note that its `.shape` is just `()` - an empty tuple),
// numpy throws a `TypeError: len() of unsized object`
irrt_assert(!this->is_unsized());
// Apparently `len(<ndarray>)` is defined to be the first dimension
// REFERENCE: https://stackoverflow.com/questions/43081809/len-of-a-numpy-array-in-python
return (SliceIndex) this->shape[0];
}
}; };
} }
@ -478,4 +509,12 @@ extern "C" {
void __nac3_ndarray_subscript64(NDArray<int64_t>* ndarray, int32_t num_slices, NDSlice* slices, NDArray<int64_t> *dst_ndarray) { void __nac3_ndarray_subscript64(NDArray<int64_t>* ndarray, int32_t num_slices, NDSlice* slices, NDArray<int64_t> *dst_ndarray) {
ndarray->subscript(num_slices, slices, dst_ndarray); ndarray->subscript(num_slices, slices, dst_ndarray);
} }
SliceIndex __nac3_ndarray_len(NDArray<int32_t>* ndarray) {
return ndarray->len();
}
SliceIndex __nac3_ndarray_len64(NDArray<int64_t>* ndarray) {
return ndarray->len();
}
} }

View File

@ -475,6 +475,50 @@ void test_ndslice_2() {
assert_values_match("dst_ndarray[1]", "%f", 9.0, *((double *) dst_ndarray.get_pelement_by_indices((int32_t[dst_ndims]) { 1 }))); assert_values_match("dst_ndarray[1]", "%f", 9.0, *((double *) dst_ndarray.get_pelement_by_indices((int32_t[dst_ndims]) { 1 })));
} }
void test_ndslice_3() {
BEGIN_TEST();
double in_data[12] = { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 };
const int32_t in_itemsize = sizeof(double);
const int32_t in_ndims = 2;
int32_t in_shape[in_ndims] = { 3, 4 };
int32_t in_strides[in_ndims] = {};
NDArray<int32_t> ndarray = {
.data = (uint8_t*) in_data,
.itemsize = in_itemsize,
.ndims = in_ndims,
.shape = in_shape,
.strides = in_strides
};
ndarray.set_strides_by_shape();
const int32_t dst_ndims = 2;
int32_t dst_shape[dst_ndims] = {999, 999}; // Empty values
int32_t dst_strides[dst_ndims] = {999, 999}; // Empty values
NDArray<int32_t> dst_ndarray = {
.data = nullptr,
.ndims = dst_ndims,
.shape = dst_shape,
.strides = dst_strides
};
// Create the slice in `ndarray[2:3]`
UserSlice user_slice_1 = {
.start_defined = 1,
.start = 2,
.stop_defined = 1,
.stop = 3,
.step_defined = 0,
};
const int32_t num_ndslices = 1;
NDSlice ndslices[num_ndslices] = {
{ .type = INPUT_SLICE_TYPE_SLICE, .slice = (uint8_t*) &user_slice_1 },
};
ndarray.subscript(num_ndslices, ndslices, &dst_ndarray);
}
void test_can_broadcast_shape() { void test_can_broadcast_shape() {
BEGIN_TEST(); BEGIN_TEST();
@ -644,6 +688,7 @@ int main() {
test_slice_4(); test_slice_4();
test_ndslice_1(); test_ndslice_1();
test_ndslice_2(); test_ndslice_2();
test_ndslice_3();
test_can_broadcast_shape(); test_can_broadcast_shape();
test_ndarray_broadcast_1(); test_ndarray_broadcast_1();
return 0; return 0;

View File

@ -1979,43 +1979,32 @@ impl<'ctx> NpArrayType<'ctx> {
/// - `ndarray.itemsize` will be initialized to the size of `self.elem_type.size_of()`. /// - `ndarray.itemsize` will be initialized to the size of `self.elem_type.size_of()`.
/// - `ndarray.shape` and `ndarray.strides` will be allocated on the stack with number of elements being `in_ndims`, /// - `ndarray.shape` and `ndarray.strides` will be allocated on the stack with number of elements being `in_ndims`,
/// all with empty/uninitialized values. /// all with empty/uninitialized values.
pub fn var_alloc<G>( pub fn alloca(
&self, &self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
in_ndims: IntValue<'ctx>, in_ndims: IntValue<'ctx>,
name: Option<&str>, name: &str,
) -> NpArrayValue<'ctx> ) -> NpArrayValue<'ctx> {
where let ptr = ctx
G: CodeGenerator + ?Sized, .builder
{ .build_alloca(self.get_struct_type(ctx.ctx).as_basic_type_enum(), name)
let ptr = generator
.gen_var_alloc(ctx, self.get_struct_type(ctx.ctx).as_basic_type_enum(), name)
.unwrap(); .unwrap();
// Allocate `in_dims` number of `size_type` on the stack for `shape` and `strides` // Allocate `in_dims` number of `size_type` on the stack for `shape` and `strides`
let allocated_shape = generator let allocated_shape = ctx
.gen_array_var_alloc( .builder
ctx, .build_array_alloca(self.size_type.as_basic_type_enum(), in_ndims, "allocated_shape")
self.size_type.as_basic_type_enum(),
in_ndims,
Some("allocated_shape"),
)
.unwrap(); .unwrap();
let allocated_strides = generator let allocated_strides = ctx
.gen_array_var_alloc( .builder
ctx, .build_array_alloca(self.size_type.as_basic_type_enum(), in_ndims, "allocated_strides")
self.size_type.as_basic_type_enum(),
in_ndims,
Some("allocated_strides"),
)
.unwrap(); .unwrap();
let value = NpArrayValue { ty: *self, ptr }; let value = NpArrayValue { ty: *self, ptr };
value.store_ndims(ctx, in_ndims); value.store_ndims(ctx, in_ndims);
value.store_itemsize(ctx, self.elem_type.size_of().unwrap()); value.store_itemsize(ctx, self.elem_type.size_of().unwrap());
value.store_shape(ctx, allocated_shape.base_ptr(ctx, generator)); value.store_shape(ctx, allocated_shape);
value.store_strides(ctx, allocated_strides.base_ptr(ctx, generator)); value.store_strides(ctx, allocated_strides);
return value; return value;
} }
@ -2045,6 +2034,11 @@ pub struct NpArrayValue<'ctx> {
} }
impl<'ctx> NpArrayValue<'ctx> { impl<'ctx> NpArrayValue<'ctx> {
pub fn load_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let field = self.ty.fields(ctx.ctx).data;
field.load(ctx, self.ptr).into_pointer_value()
}
pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, new_data_ptr: PointerValue<'ctx>) { pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, new_data_ptr: PointerValue<'ctx>) {
let field = self.ty.fields(ctx.ctx).data; let field = self.ty.fields(ctx.ctx).data;
field.store(ctx, self.ptr, new_data_ptr); field.store(ctx, self.ptr, new_data_ptr);

View File

@ -17,7 +17,8 @@ use crate::{
call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax, call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax,
call_memcpy_generic, call_memcpy_generic,
}, },
need_sret, numpy::{self, call_ndarray_subscript_impl}, need_sret,
numpy::{self, call_ndarray_subscript_impl, get_ndarray_first_element},
stmt::{ stmt::{
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise, gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
gen_var, gen_var,
@ -38,7 +39,7 @@ use crate::{
use inkwell::{ use inkwell::{
attributes::{Attribute, AttributeLoc}, attributes::{Attribute, AttributeLoc},
types::{AnyType, BasicType, BasicTypeEnum}, types::{AnyType, BasicType, BasicTypeEnum},
values::{BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue}, values::{BasicValue, BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue},
AddressSpace, IntPredicate, OptimizationLevel, AddressSpace, IntPredicate, OptimizationLevel,
}; };
use itertools::{chain, izip, Either, Itertools}; use itertools::{chain, izip, Either, Itertools};
@ -2100,12 +2101,14 @@ pub fn gen_cmpop_expr<'ctx, G: CodeGenerator>(
/// Generates code for a subscript expression on an `ndarray`. /// Generates code for a subscript expression on an `ndarray`.
/// ///
/// * `ty` - The `Type` of the `NDArray` elements. /// * `ty` - The `Type` of the `NDArray` elements.
/// * `ndims` - The `Type` of the `NDArray` number-of-dimensions `Literal`.
/// * `ndarray` - The `NDArray` value. /// * `ndarray` - The `NDArray` value.
/// * `slice` - The slice expression used to subscript into the `ndarray`. /// * `slice` - The slice expression used to subscript into the `ndarray`.
fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
ty: Type, ty: Type,
ndims: Type,
ndarray: NpArrayValue<'ctx>, ndarray: NpArrayValue<'ctx>,
slice: &Expr<Option<Type>>, slice: &Expr<Option<Type>>,
) -> Result<Option<ValueEnum<'ctx>>, String> { ) -> Result<Option<ValueEnum<'ctx>>, String> {
@ -2165,8 +2168,8 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
let stop = help(stop)?; let stop = help(stop)?;
let step = help(step)?; let step = help(step)?;
// NOTE: Now start stop step should all be 32-bit ints after typechecking // start stop step should all be 32-bit ints after typechecking,
// ...and `IrrtUserSlice` expects `int32`s // and `IrrtUserSlice` expects `int32`s
NDSlice::Slice(UserSlice { start, stop, step }) NDSlice::Slice(UserSlice { start, stop, step })
} }
_ => { _ => {
@ -2185,12 +2188,36 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
ndslices.push(ndslice); ndslices.push(ndslice);
} }
// Finally, perform the actual subscript logic // TODO: what is going on? why the original implementation doesn't assert `ndims_values.len() == 1`
let subndarray = call_ndarray_subscript_impl(generator, ctx, ndarray, &ndslices.iter().collect_vec())?; // Extract the `ndims` from a `Type` to `i128`
let TypeEnum::TLiteral { values: ndims_values, .. } = &*ctx.unifier.get_ty_immutable(ndims)
else {
unreachable!()
};
assert_eq!(ndims_values.len(), 1);
let ndims = i128::try_from(ndims_values[0].clone()).unwrap() as u64;
assert!(ndims > 0);
// ...and return the result // Deduce the subndarray's ndims
let result = ValueEnum::Dynamic(subndarray.ptr.into()); let dst_ndims = deduce_ndims_after_slicing(ndims, ndslices.iter());
Ok(Some(result))
// Finally, perform the actual subscript logic
let subndarray = call_ndarray_subscript_impl(
generator,
ctx,
ndarray,
&ndslices.iter().collect_vec(),
)?;
// ...and return the result, with two cases
let result = if dst_ndims == 0 {
// 1) ndims == 0 (this happens when you do `np.zerps((3, 4))[1, 1]`), return *THE ELEMENT*
get_ndarray_first_element(ctx, subndarray, "element")
} else {
// 2) ndims > 0 (other cases), return subndarray
subndarray.ptr.as_basic_value_enum()
};
Ok(Some(ValueEnum::Dynamic(result)))
// let llvm_i1 = ctx.ctx.bool_type(); // let llvm_i1 = ctx.ctx.bool_type();
// let llvm_i32 = ctx.ctx.i32_type(); // let llvm_i32 = ctx.ctx.i32_type();
@ -3116,7 +3143,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
} }
} }
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::NDArray.id() => {
let (elem_ty, _) = params.iter().map(|(_, ty)| ty).collect_tuple().unwrap(); let (elem_ty, ndims) = params.iter().map(|(_, ty)| ty).collect_tuple().unwrap();
// Get the pointer to the ndarray described by `value` // Get the pointer to the ndarray described by `value`
let ndarray_ptr = if let Some(v) = generator.gen_expr(ctx, value)? { let ndarray_ptr = if let Some(v) = generator.gen_expr(ctx, value)? {
@ -3136,7 +3163,9 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
let ndarray = ndarray_ty.value_from_ptr(ctx.ctx, ndarray_ptr); let ndarray = ndarray_ty.value_from_ptr(ctx.ctx, ndarray_ptr);
// Implementation // Implementation
return gen_ndarray_subscript_expr(generator, ctx, *elem_ty, ndarray, slice); return gen_ndarray_subscript_expr(
generator, ctx, *elem_ty, *ndims, ndarray, slice,
);
} }
TypeEnum::TTuple { .. } => { TypeEnum::TTuple { .. } => {
let index: u32 = let index: u32 =

View File

@ -18,6 +18,7 @@ use super::{
}; };
use crate::codegen::classes::TypedArrayLikeAccessor; use crate::codegen::classes::TypedArrayLikeAccessor;
use crate::codegen::stmt::gen_for_callback_incrementing; use crate::codegen::stmt::gen_for_callback_incrementing;
use crossbeam::channel::IntoIter;
use inkwell::{ use inkwell::{
attributes::{Attribute, AttributeLoc}, attributes::{Attribute, AttributeLoc},
context::Context, context::Context,
@ -25,7 +26,7 @@ use inkwell::{
module::Module, module::Module,
types::{BasicType, BasicTypeEnum, FunctionType, IntType, PointerType, StructType}, types::{BasicType, BasicTypeEnum, FunctionType, IntType, PointerType, StructType},
values::{ values::{
BasicValue, BasicValueEnum, CallSiteValue, FloatValue, FunctionValue, IntValue, AnyValue, BasicValue, BasicValueEnum, CallSiteValue, FloatValue, FunctionValue, IntValue,
PointerValue, PointerValue,
}, },
AddressSpace, IntPredicate, AddressSpace, IntPredicate,
@ -1049,6 +1050,25 @@ pub enum NDSlice<'ctx> {
// TODO: Support https://numpy.org/doc/stable/user/basics.indexing.html#dimensional-indexing-tools; *should* be very easy to implement // TODO: Support https://numpy.org/doc/stable/user/basics.indexing.html#dimensional-indexing-tools; *should* be very easy to implement
} }
pub fn deduce_ndims_after_slicing<'ctx, I>(ndims: u64, ndslices: I) -> u64
where
I: Iterator<Item = &'ctx NDSlice<'ctx>>,
{
let mut final_ndims = ndims;
for ndslice in ndslices {
match ndslice {
NDSlice::Index(_) => {
// Index demote the output rank by 1
final_ndims -= 1;
}
NDSlice::Slice(_) => {
// Rank isn't changed
}
}
}
final_ndims
}
// TODO: Empty struct // TODO: Empty struct
pub struct IrrtNDSlice {} pub struct IrrtNDSlice {}
@ -1280,7 +1300,7 @@ pub fn call_nac3_ndarray_set_strides_by_shape<'ctx>(
.unwrap(); .unwrap();
} }
pub fn call_nac3_ndarray_deduce_ndims_after_slicing_raw<'ctx>( pub fn call_nac3_ndarray_deduce_ndims_after_slicing<'ctx>(
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
size_type: IntType<'ctx>, size_type: IntType<'ctx>,
ndims: IntValue<'ctx>, ndims: IntValue<'ctx>,
@ -1365,3 +1385,39 @@ pub fn call_nac3_ndarray_subscript<'ctx>(
) )
.unwrap(); .unwrap();
} }
pub fn call_nac3_len<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
ndarray: NpArrayValue<'ctx>,
) -> IntValue<'ctx> {
let size_type = ndarray.ty.size_type;
// Get the IRRT function
let function = get_size_type_dependent_function(
ctx,
size_type,
"__nac3_ndarray_deduce_ndims_after_slicing",
|| {
get_sliceindex_type(ctx.ctx).fn_type(
&[
get_irrt_ndarray_ptr_type(ctx.ctx, size_type).into(), // NDArray<SizeT> *ndarray
],
false,
)
},
);
// Call the IRRT function
ctx.builder
.build_call(
function,
&[
ndarray.ptr.into(), // ndarray
],
"len_of_ndarray",
)
.unwrap()
.try_as_basic_value()
.unwrap_left()
.into_int_value()
}

View File

@ -36,7 +36,7 @@ use nac3parser::ast::{Operator, StrRef};
use super::{ use super::{
classes::NpArrayValue, classes::NpArrayValue,
irrt::{ irrt::{
call_nac3_ndarray_deduce_ndims_after_slicing_raw, call_nac3_ndarray_set_strides_by_shape, call_nac3_ndarray_deduce_ndims_after_slicing, call_nac3_ndarray_set_strides_by_shape,
call_nac3_ndarray_size, call_nac3_ndarray_subscript, get_irrt_ndarray_ptr_type, call_nac3_ndarray_size, call_nac3_ndarray_subscript, get_irrt_ndarray_ptr_type,
get_opaque_uint8_ptr_type, IrrtNDSlice, NDSlice, get_opaque_uint8_ptr_type, IrrtNDSlice, NDSlice,
}, },
@ -2087,19 +2087,19 @@ fn copy_array_slice<'ctx, G, Src, Dst>(
.unwrap(); .unwrap();
} }
fn var_alloc_ndarray<'ctx, G>( fn alloca_ndarray<'ctx, G>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
elem_type: BasicTypeEnum<'ctx>, elem_type: BasicTypeEnum<'ctx>,
ndims: IntValue<'ctx>, ndims: IntValue<'ctx>,
name: Option<&str>, name: &str,
) -> Result<NpArrayValue<'ctx>, String> ) -> Result<NpArrayValue<'ctx>, String>
where where
G: CodeGenerator + ?Sized, G: CodeGenerator + ?Sized,
{ {
let size_type = generator.get_size_type(ctx.ctx); let size_type = generator.get_size_type(ctx.ctx);
let ndarray_ty = NpArrayType { size_type, elem_type }; let ndarray_ty = NpArrayType { size_type, elem_type };
let ndarray = ndarray_ty.var_alloc(generator, ctx, ndims, name); let ndarray = ndarray_ty.alloca(ctx, ndims, name);
Ok(ndarray) Ok(ndarray)
} }
@ -2270,7 +2270,7 @@ fn alloca_ndarray_and_init<'ctx, G>(
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
elem_type: BasicTypeEnum<'ctx>, elem_type: BasicTypeEnum<'ctx>,
init_mode: NDArrayInitMode<'ctx, G>, init_mode: NDArrayInitMode<'ctx, G>,
name: Option<&str>, name: &str,
) -> Result<NpArrayValue<'ctx>, String> ) -> Result<NpArrayValue<'ctx>, String>
where where
G: CodeGenerator + ?Sized, G: CodeGenerator + ?Sized,
@ -2278,12 +2278,12 @@ where
// It is implemented verbosely in order to make the initialization modes super clear in their intent. // It is implemented verbosely in order to make the initialization modes super clear in their intent.
match init_mode { match init_mode {
NDArrayInitMode::SetNDim { ndim: ndims } => { NDArrayInitMode::SetNDim { ndim: ndims } => {
let ndarray = var_alloc_ndarray(generator, ctx, elem_type, ndims, name)?; let ndarray = alloca_ndarray(generator, ctx, elem_type, ndims, name)?;
Ok(ndarray) Ok(ndarray)
} }
NDArrayInitMode::SetShape { shape } => { NDArrayInitMode::SetShape { shape } => {
let ndims = shape.count; let ndims = shape.count;
let ndarray = var_alloc_ndarray(generator, ctx, elem_type, ndims, name)?; let ndarray = alloca_ndarray(generator, ctx, elem_type, ndims, name)?;
// Fill `ndarray.shape` with `shape_producer` // Fill `ndarray.shape` with `shape_producer`
(shape.write_to_slice)(generator, ctx, &ndarray.shape_slice(ctx)); (shape.write_to_slice)(generator, ctx, &ndarray.shape_slice(ctx));
@ -2292,7 +2292,7 @@ where
} }
NDArrayInitMode::SetShapeAndAllocaData { shape } => { NDArrayInitMode::SetShapeAndAllocaData { shape } => {
let ndims = shape.count; let ndims = shape.count;
let ndarray = var_alloc_ndarray(generator, ctx, elem_type, ndims, name)?; let ndarray = alloca_ndarray(generator, ctx, elem_type, ndims, name)?;
// Fill `ndarray.shape` with `shape_producer` // Fill `ndarray.shape` with `shape_producer`
(shape.write_to_slice)(generator, ctx, &ndarray.shape_slice(ctx)); (shape.write_to_slice)(generator, ctx, &ndarray.shape_slice(ctx));
@ -2320,6 +2320,28 @@ where
} }
} }
pub fn get_ndarray_first_element<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
ndarray: NpArrayValue<'ctx>,
name: &str,
) -> BasicValueEnum<'ctx> {
let data = ndarray.load_data(ctx);
// Cast `data` to the actual element the `subndarray` holds
// otherwise `subndarray.data` is just a bunch of `uint8_t*`
let data = ctx
.builder
.build_pointer_cast(
data,
ndarray.ty.elem_type.ptr_type(AddressSpace::default()),
"data_casted",
)
.unwrap();
// Load the element
ctx.builder.build_load(data, name).unwrap()
}
pub fn call_ndarray_subscript_impl<'ctx, G>( pub fn call_ndarray_subscript_impl<'ctx, G>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
@ -2344,10 +2366,10 @@ where
// Prepare the argument `slices` // Prepare the argument `slices`
let ndslices_ptr = IrrtNDSlice::alloca_ndslices(ctx, ndslices); let ndslices_ptr = IrrtNDSlice::alloca_ndslices(ctx, ndslices);
// Deduce the ndims // Get `dst_ndims`
let dst_ndims = call_nac3_ndarray_deduce_ndims_after_slicing_raw( let dst_ndims = call_nac3_ndarray_deduce_ndims_after_slicing(
ctx, ctx,
ndarray.ty.size_type, size_type,
ndims, ndims,
num_slices, num_slices,
ndslices_ptr, ndslices_ptr,
@ -2359,7 +2381,7 @@ where
ctx, ctx,
ndarray.ty.elem_type, ndarray.ty.elem_type,
NDArrayInitMode::SetNDim { ndim: dst_ndims }, NDArrayInitMode::SetNDim { ndim: dst_ndims },
Some("subndarray"), "subndarray",
)?; )?;
call_nac3_ndarray_subscript(ctx, ndarray, num_slices, ndslices_ptr, dst_ndarray); call_nac3_ndarray_subscript(ctx, ndarray, num_slices, ndslices_ptr, dst_ndarray);
@ -2374,7 +2396,7 @@ fn call_ndarray_empty_impl<'ctx, G>(
elem_ty: Type, elem_ty: Type,
shape: BasicValueEnum<'ctx>, shape: BasicValueEnum<'ctx>,
shape_ty: Type, shape_ty: Type,
name: Option<&str>, name: &str,
) -> Result<NpArrayValue<'ctx>, String> ) -> Result<NpArrayValue<'ctx>, String>
where where
G: CodeGenerator + ?Sized, G: CodeGenerator + ?Sized,
@ -2398,7 +2420,7 @@ fn call_ndarray_fill_impl<'ctx, G>(
shape: BasicValueEnum<'ctx>, shape: BasicValueEnum<'ctx>,
shape_ty: Type, shape_ty: Type,
fill_value: BasicValueEnum<'ctx>, fill_value: BasicValueEnum<'ctx>,
name: Option<&str>, name: &str,
) -> Result<NpArrayValue<'ctx>, String> ) -> Result<NpArrayValue<'ctx>, String>
where where
G: CodeGenerator + ?Sized, G: CodeGenerator + ?Sized,
@ -2430,7 +2452,7 @@ pub fn gen_ndarray_empty<'ctx>(
context.primitives.float, context.primitives.float,
shape, shape,
shape_ty, shape_ty,
None, "empty_ndarray",
)?; )?;
Ok(ndarray.ptr) Ok(ndarray.ptr)
} }
@ -2462,7 +2484,7 @@ pub fn gen_ndarray_zeros<'ctx>(
shape, shape,
shape_ty, shape_ty,
float64_llvm_type.const_zero().as_basic_value_enum(), float64_llvm_type.const_zero().as_basic_value_enum(),
Some("np_zeros.result"), "zeros_ndarray",
)?; )?;
Ok(ndarray.ptr) Ok(ndarray.ptr)
} }
@ -2494,7 +2516,7 @@ pub fn gen_ndarray_ones<'ctx>(
shape, shape,
shape_ty, shape_ty,
float64_llvm_type.const_float(1.0).as_basic_value_enum(), float64_llvm_type.const_float(1.0).as_basic_value_enum(),
Some("np_ones.result"), "ones_ndarray",
)?; )?;
Ok(ndarray.ptr) Ok(ndarray.ptr)
} }

View File

@ -1,6 +1,6 @@
use std::iter::once; use std::iter::once;
use crate::util::SizeVariant; use crate::{codegen::classes::NpArrayType, util::SizeVariant};
use helper::{debug_assert_prim_is_allowed, make_exception_fields, PrimDefDetails}; use helper::{debug_assert_prim_is_allowed, make_exception_fields, PrimDefDetails};
use indexmap::IndexMap; use indexmap::IndexMap;
use inkwell::{ use inkwell::{
@ -1196,7 +1196,7 @@ impl<'a> BuiltinBuilder<'a> {
let func = match prim { let func = match prim {
PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => gen_ndarray_empty, PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => gen_ndarray_empty,
PrimDef::FunNpZeros => gen_ndarray_zeros, PrimDef::FunNpZeros => gen_ndarray_zeros,
PrimDef::FunNpOnes => todo!(), // gen_ndarray_ones, PrimDef::FunNpOnes => gen_ndarray_ones, // gen_ndarray_ones,
_ => unreachable!(), _ => unreachable!(),
}; };
func(ctx, &obj, fun, &args, generator).map(|val| Some(val.as_basic_value_enum())) func(ctx, &obj, fun, &args, generator).map(|val| Some(val.as_basic_value_enum()))
@ -1460,51 +1460,62 @@ impl<'a> BuiltinBuilder<'a> {
} }
} }
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let llvm_i32 = ctx.ctx.i32_type(); // TODO: Check is unsized and throw error if so
let llvm_usize = generator.get_size_type(ctx.ctx);
let arg = NDArrayValue::from_ptr_val( // Parse `arg`
arg.into_pointer_value(), let ndarray_ptr = arg.into_pointer_value(); // It has to be an ndarray
llvm_usize,
None,
);
let ndims = arg.dim_sizes().size(ctx, generator); let size_type = generator.get_size_type(ctx.ctx);
ctx.make_assert( let ndarray_ty = NpArrayType::new_opaque_elem(ctx, size_type); // We don't need to care about the element type - we only want the shape
generator, let ndarray = ndarray_ty.value_from_ptr(ctx.ctx, ndarray_ptr);
ctx.builder
.build_int_compare(
IntPredicate::NE,
ndims,
llvm_usize.const_zero(),
"",
)
.unwrap(),
"0:TypeError",
&format!("{name}() of unsized object", name = prim.name()),
[None, None, None],
ctx.current_loc,
);
let len = unsafe { Some(call_nac3_len(ctx, ndarray).as_basic_value_enum())
arg.dim_sizes().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_zero(),
None,
)
};
if len.get_type().get_bit_width() == 32 { // let llvm_i32 = ctx.ctx.i32_type();
Some(len.into()) // let llvm_usize = generator.get_size_type(ctx.ctx);
} else {
Some( // let arg = NDArrayValue::from_ptr_val(
ctx.builder // arg.into_pointer_value(),
.build_int_truncate(len, llvm_i32, "len") // llvm_usize,
.map(Into::into) // None,
.unwrap(), // );
)
} // let ndims = arg.dim_sizes().size(ctx, generator);
// ctx.make_assert(
// generator,
// ctx.builder
// .build_int_compare(
// IntPredicate::NE,
// ndims,
// llvm_usize.const_zero(),
// "",
// )
// .unwrap(),
// "0:TypeError",
// &format!("{name}() of unsized object", name = prim.name()),
// [None, None, None],
// ctx.current_loc,
// );
// let len = unsafe {
// arg.dim_sizes().get_typed_unchecked(
// ctx,
// generator,
// &llvm_usize.const_zero(),
// None,
// )
// };
// if len.get_type().get_bit_width() == 32 {
// Some(len.into())
// } else {
// Some(
// ctx.builder
// .build_int_truncate(len, llvm_i32, "len")
// .map(Into::into)
// .unwrap(),
// )
// }
} }
_ => unreachable!(), _ => unreachable!(),
} }

View File

@ -1,3 +1,18 @@
@extern
def output_float64(x: float):
...
def output_ndarray_float_1(n: ndarray[float, Literal[1]]):
for i in range(len(n)):
output_float64(n[i])
def output_ndarray_float_2(n: ndarray[float, Literal[2]]):
for r in range(len(n)):
for c in range(len(n[r])):
output_float64(n[r][c])
def run() -> int32: def run() -> int32:
hello = np_zeros((3, 4)) hello = np_ones((3, 4))
# output_float64(hello[2, 3])
output_ndarray_float_1(hello[::-2, 2])
return 0 return 0