forked from M-Labs/nac3
1
0
Fork 0

Compare commits

..

9 Commits

Author SHA1 Message Date
lyken e719d9396d asd 2024-07-11 00:34:06 +08:00
lyken 27f2e8b391 core: irrt general numpy broadcasting 2024-07-10 20:10:14 +08:00
lyken 5f4c406b37 core: irrt general numpy slicing 2024-07-10 20:10:14 +08:00
lyken 31ab9675ca core: more irrt 2024-07-10 20:10:14 +08:00
lyken 5fd5d65377 core: build.rs rewrite regex to capture `= type` 2024-07-10 20:10:14 +08:00
lyken 01042aecfb core: move irrt c++ sources to /nac3core/irrt 2024-07-10 20:10:14 +08:00
lyken 8754f252f6 core: IRRT -Werror=return-type 2024-07-10 20:10:14 +08:00
lyken 17207a4ebe core: add irrt_test 2024-07-10 20:10:14 +08:00
lyken e3a4675fc6 core: comment out numpy 2024-07-10 20:10:05 +08:00
12 changed files with 989 additions and 566 deletions

View File

@ -212,11 +212,11 @@ namespace {
return this->size() * itemsize; return this->size() * itemsize;
} }
void set_value_at_pelement(uint8_t* pelement, const uint8_t* pvalue) { void set_pelement_value(uint8_t* pelement, const uint8_t* pvalue) {
__builtin_memcpy(pelement, pvalue, itemsize); __builtin_memcpy(pelement, pvalue, itemsize);
} }
uint8_t* get_pelement(const SizeT *indices) { uint8_t* get_pelement_by_indices(const SizeT *indices) {
uint8_t* element = data; uint8_t* element = data;
for (SizeT dim_i = 0; dim_i < ndims; dim_i++) for (SizeT dim_i = 0; dim_i < ndims; dim_i++)
element += indices[dim_i] * strides[dim_i]; element += indices[dim_i] * strides[dim_i];
@ -229,7 +229,7 @@ namespace {
SizeT* indices = (SizeT*) __builtin_alloca(sizeof(SizeT) * this->ndims); SizeT* indices = (SizeT*) __builtin_alloca(sizeof(SizeT) * this->ndims);
ndarray_util::set_indices_by_nth(this->ndims, this->shape, indices, nth); ndarray_util::set_indices_by_nth(this->ndims, this->shape, indices, nth);
return get_pelement(indices); return get_pelement_by_indices(indices);
} }
// Get pointer to the first element of this ndarray, assuming // Get pointer to the first element of this ndarray, assuming
@ -259,8 +259,8 @@ namespace {
iter.set_indices_zero(); iter.set_indices_zero();
for (SizeT i = 0; i < this->size(); i++, iter.next()) { for (SizeT i = 0; i < this->size(); i++, iter.next()) {
uint8_t* pelement = get_pelement(iter.indices); uint8_t* pelement = get_pelement_by_indices(iter.indices);
set_value_at_pelement(pelement, pvalue); set_pelement_value(pelement, pvalue);
} }
} }
@ -283,8 +283,8 @@ namespace {
if (!in_bounds(indices)) continue; if (!in_bounds(indices)) continue;
uint8_t* pelement = get_pelement(indices); uint8_t* pelement = get_pelement_by_indices(indices);
set_value_at_pelement(pelement, one_pvalue); set_pelement_value(pelement, one_pvalue);
} }
} }
@ -435,9 +435,9 @@ namespace {
}; };
const SizeT this_size = this->size(); const SizeT this_size = this->size();
for (SizeT i = 0; i < this_size; i++, iter.next()) { for (SizeT i = 0; i < this_size; i++, iter.next()) {
uint8_t* src_pelement = broadcasted_src_ndarray_strides->get_pelement(indices); uint8_t* src_pelement = broadcasted_src_ndarray_strides->get_pelement_by_indices(indices);
uint8_t* this_pelement = this->get_pelement(indices); uint8_t* this_pelement = this->get_pelement_by_indices(indices);
this->set_value_at_pelement(src_pelement, src_pelement); this->set_pelement_value(src_pelement, src_pelement);
} }
} }
}; };

View File

@ -81,7 +81,7 @@ void __print_ndarray_aux(const char *format, bool first, bool last, SizeT* curso
SizeT* indices = (SizeT*) __builtin_alloca(sizeof(SizeT) * ndarray->ndims); SizeT* indices = (SizeT*) __builtin_alloca(sizeof(SizeT) * ndarray->ndims);
for (SizeT i = 0; i < dim; i++) { for (SizeT i = 0; i < dim; i++) {
ndarray_util::set_indices_by_nth(ndarray->ndims, ndarray->shape, indices, *cursor); ndarray_util::set_indices_by_nth(ndarray->ndims, ndarray->shape, indices, *cursor);
ElementT* pelement = (ElementT*) ndarray->get_pelement(indices); ElementT* pelement = (ElementT*) ndarray->get_pelement_by_indices(indices);
ElementT element = *pelement; ElementT element = *pelement;
if (i != 0) printf(", "); // List delimiter if (i != 0) printf(", "); // List delimiter
@ -394,10 +394,10 @@ void test_ndslice_1() {
assert_arrays_match("shape", "%d", dst_ndims, expected_shape, dst_ndarray.shape); assert_arrays_match("shape", "%d", dst_ndims, expected_shape, dst_ndarray.shape);
assert_arrays_match("strides", "%d", dst_ndims, expected_strides, dst_ndarray.strides); assert_arrays_match("strides", "%d", dst_ndims, expected_strides, dst_ndarray.strides);
assert_values_match("dst_ndarray[0, 0]", "%f", 5.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 0, 0 }))); assert_values_match("dst_ndarray[0, 0]", "%f", 5.0, *((double *) dst_ndarray.get_pelement_by_indices((int32_t[dst_ndims]) { 0, 0 })));
assert_values_match("dst_ndarray[0, 1]", "%f", 7.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 0, 1 }))); assert_values_match("dst_ndarray[0, 1]", "%f", 7.0, *((double *) dst_ndarray.get_pelement_by_indices((int32_t[dst_ndims]) { 0, 1 })));
assert_values_match("dst_ndarray[1, 0]", "%f", 9.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 1, 0 }))); assert_values_match("dst_ndarray[1, 0]", "%f", 9.0, *((double *) dst_ndarray.get_pelement_by_indices((int32_t[dst_ndims]) { 1, 0 })));
assert_values_match("dst_ndarray[1, 1]", "%f", 11.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 1, 1 }))); assert_values_match("dst_ndarray[1, 1]", "%f", 11.0, *((double *) dst_ndarray.get_pelement_by_indices((int32_t[dst_ndims]) { 1, 1 })));
} }
void test_ndslice_2() { void test_ndslice_2() {
@ -471,8 +471,8 @@ void test_ndslice_2() {
assert_arrays_match("strides", "%d", dst_ndims, expected_strides, dst_ndarray.strides); assert_arrays_match("strides", "%d", dst_ndims, expected_strides, dst_ndarray.strides);
// [5.0, 3.0] // [5.0, 3.0]
assert_values_match("dst_ndarray[0]", "%f", 11.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 0 }))); assert_values_match("dst_ndarray[0]", "%f", 11.0, *((double *) dst_ndarray.get_pelement_by_indices((int32_t[dst_ndims]) { 0 })));
assert_values_match("dst_ndarray[1]", "%f", 9.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 1 }))); assert_values_match("dst_ndarray[1]", "%f", 9.0, *((double *) dst_ndarray.get_pelement_by_indices((int32_t[dst_ndims]) { 1 })));
} }
void test_can_broadcast_shape() { void test_can_broadcast_shape() {
@ -618,24 +618,15 @@ void test_ndarray_broadcast_1() {
assert_arrays_match("dst_ndarray->strides", "%d", dst_ndims, (int32_t[]) { 0, 0, 8 }, dst_ndarray.strides); assert_arrays_match("dst_ndarray->strides", "%d", dst_ndims, (int32_t[]) { 0, 0, 8 }, dst_ndarray.strides);
assert_values_match("dst_ndarray[0, 0, 0]", "%f", 19.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 0, 0}))); assert_values_match("dst_ndarray[0, 0, 0]", "%f", 19.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 0, 0})));
assert_values_match("dst_ndarray[0, 0, 1]", "%f", 29.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 0, 1}))); assert_values_match("dst_ndarray[0, 0, 1]", "%f", 29.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 0, 1})));
assert_values_match("dst_ndarray[0, 0, 2]", "%f", 39.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 0, 2}))); assert_values_match("dst_ndarray[0, 0, 2]", "%f", 39.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 0, 2})));
assert_values_match("dst_ndarray[0, 0, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 0, 3}))); assert_values_match("dst_ndarray[0, 0, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 0, 3})));
assert_values_match("dst_ndarray[0, 1, 0]", "%f", 19.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 1, 0}))); assert_values_match("dst_ndarray[0, 1, 0]", "%f", 19.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 1, 0})));
assert_values_match("dst_ndarray[0, 1, 1]", "%f", 29.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 1, 1}))); assert_values_match("dst_ndarray[0, 1, 1]", "%f", 29.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 1, 1})));
assert_values_match("dst_ndarray[0, 1, 2]", "%f", 39.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 1, 2}))); assert_values_match("dst_ndarray[0, 1, 2]", "%f", 39.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 1, 2})));
assert_values_match("dst_ndarray[0, 1, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 1, 3}))); assert_values_match("dst_ndarray[0, 1, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 1, 3})));
assert_values_match("dst_ndarray[1, 2, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {1, 2, 3}))); assert_values_match("dst_ndarray[1, 2, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {1, 2, 3})));
}
void test_assign_with() {
/*
```
xs = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype=np.float64)
ys = xs.shape
```
*/
} }
int main() { int main() {
@ -653,6 +644,5 @@ int main() {
test_ndslice_2(); test_ndslice_2();
test_can_broadcast_shape(); test_can_broadcast_shape();
test_ndarray_broadcast_1(); test_ndarray_broadcast_1();
test_assign_with();
return 0; return 0;
} }

View File

@ -702,53 +702,54 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); todo!()
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); // let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
// let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); // let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None)); // let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None));
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { // if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let n_sz_eqz = ctx // let n_sz_eqz = ctx
.builder // .builder
.build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "") // .build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "")
.unwrap(); // .unwrap();
ctx.make_assert( // ctx.make_assert(
generator, // generator,
n_sz_eqz, // n_sz_eqz,
"0:ValueError", // "0:ValueError",
"zero-size array to reduction operation minimum which has no identity", // "zero-size array to reduction operation minimum which has no identity",
[None, None, None], // [None, None, None],
ctx.current_loc, // ctx.current_loc,
); // );
} // }
let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?; // let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
unsafe { // unsafe {
let identity = // let identity =
n.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None); // n.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
ctx.builder.build_store(accumulator_addr, identity).unwrap(); // ctx.builder.build_store(accumulator_addr, identity).unwrap();
} // }
gen_for_callback_incrementing( // gen_for_callback_incrementing(
generator, // generator,
ctx, // ctx,
llvm_usize.const_int(1, false), // llvm_usize.const_int(1, false),
(n_sz, false), // (n_sz, false),
|generator, ctx, _, idx| { // |generator, ctx, _, idx| {
let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) }; // let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) };
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); // let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
let result = call_min(ctx, (elem_ty, accumulator), (elem_ty, elem)); // let result = call_min(ctx, (elem_ty, accumulator), (elem_ty, elem));
ctx.builder.build_store(accumulator_addr, result).unwrap(); // ctx.builder.build_store(accumulator_addr, result).unwrap();
Ok(()) // Ok(())
}, // },
llvm_usize.const_int(1, false), // llvm_usize.const_int(1, false),
)?; // )?;
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); // let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
accumulator // accumulator
} }
_ => unsupported_type(ctx, FN_NAME, &[a_ty]), _ => unsupported_type(ctx, FN_NAME, &[a_ty]),
@ -920,53 +921,54 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n) BasicValueEnum::PointerValue(n)
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) => if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{ {
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty); todo!()
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); // let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
// let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None); // let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None)); // let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None));
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { // if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let n_sz_eqz = ctx // let n_sz_eqz = ctx
.builder // .builder
.build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "") // .build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "")
.unwrap(); // .unwrap();
ctx.make_assert( // ctx.make_assert(
generator, // generator,
n_sz_eqz, // n_sz_eqz,
"0:ValueError", // "0:ValueError",
"zero-size array to reduction operation minimum which has no identity", // "zero-size array to reduction operation minimum which has no identity",
[None, None, None], // [None, None, None],
ctx.current_loc, // ctx.current_loc,
); // );
} // }
let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?; // let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
unsafe { // unsafe {
let identity = // let identity =
n.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None); // n.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
ctx.builder.build_store(accumulator_addr, identity).unwrap(); // ctx.builder.build_store(accumulator_addr, identity).unwrap();
} // }
gen_for_callback_incrementing( // gen_for_callback_incrementing(
generator, // generator,
ctx, // ctx,
llvm_usize.const_int(1, false), // llvm_usize.const_int(1, false),
(n_sz, false), // (n_sz, false),
|generator, ctx, _, idx| { // |generator, ctx, _, idx| {
let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) }; // let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) };
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); // let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
let result = call_max(ctx, (elem_ty, accumulator), (elem_ty, elem)); // let result = call_max(ctx, (elem_ty, accumulator), (elem_ty, elem));
ctx.builder.build_store(accumulator_addr, result).unwrap(); // ctx.builder.build_store(accumulator_addr, result).unwrap();
Ok(()) // Ok(())
}, // },
llvm_usize.const_int(1, false), // llvm_usize.const_int(1, false),
)?; // )?;
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap(); // let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
accumulator // accumulator
} }
_ => unsupported_type(ctx, FN_NAME, &[a_ty]), _ => unsupported_type(ctx, FN_NAME, &[a_ty]),

View File

@ -1,5 +1,8 @@
use crate::codegen::{ use crate::codegen::{
llvm_intrinsics::call_int_umin, stmt::gen_for_callback_incrementing, CodeGenContext, // irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
llvm_intrinsics::call_int_umin,
stmt::gen_for_callback_incrementing,
CodeGenContext,
CodeGenerator, CodeGenerator,
}; };
use inkwell::context::Context; use inkwell::context::Context;
@ -1207,25 +1210,27 @@ impl<'ctx> NDArrayType<'ctx> {
ctx: &'ctx Context, ctx: &'ctx Context,
dtype: BasicTypeEnum<'ctx>, dtype: BasicTypeEnum<'ctx>,
) -> Self { ) -> Self {
let llvm_usize = generator.get_size_type(ctx); todo!()
// struct NDArray { num_dims: size_t, dims: size_t*, data: T* } // let llvm_usize = generator.get_size_type(ctx);
//
// * num_dims: Number of dimensions in the array
// * dims: Pointer to an array containing the size of each dimension
// * data: Pointer to an array containing the array data
let llvm_ndarray = ctx
.struct_type(
&[
llvm_usize.into(),
llvm_usize.ptr_type(AddressSpace::default()).into(),
dtype.ptr_type(AddressSpace::default()).into(),
],
false,
)
.ptr_type(AddressSpace::default());
NDArrayType::from_type(llvm_ndarray, llvm_usize) // // struct NDArray { num_dims: size_t, dims: size_t*, data: T* }
// //
// // * num_dims: Number of dimensions in the array
// // * dims: Pointer to an array containing the size of each dimension
// // * data: Pointer to an array containing the array data
// let llvm_ndarray = ctx
// .struct_type(
// &[
// llvm_usize.into(),
// llvm_usize.ptr_type(AddressSpace::default()).into(),
// dtype.ptr_type(AddressSpace::default()).into(),
// ],
// false,
// )
// .ptr_type(AddressSpace::default());
// NDArrayType::from_type(llvm_ndarray, llvm_usize)
} }
/// Creates an [`NDArrayType`] from a [`PointerType`]. /// Creates an [`NDArrayType`] from a [`PointerType`].
@ -1659,23 +1664,22 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
indices: &Index, indices: &Index,
name: Option<&str>, name: Option<&str>,
) -> PointerValue<'ctx> { ) -> PointerValue<'ctx> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let indices_elem_ty = indices
.ptr_offset(ctx, generator, &llvm_usize.const_zero(), None)
.get_type()
.get_element_type();
let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else {
panic!("Expected list[int32] but got {indices_elem_ty}")
};
assert_eq!(
indices_elem_ty.get_bit_width(),
32,
"Expected list[int32] but got list[int{}]",
indices_elem_ty.get_bit_width()
);
todo!() todo!()
// let llvm_usize = generator.get_size_type(ctx.ctx);
// let indices_elem_ty = indices
// .ptr_offset(ctx, generator, &llvm_usize.const_zero(), None)
// .get_type()
// .get_element_type();
// let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else {
// panic!("Expected list[int32] but got {indices_elem_ty}")
// };
// assert_eq!(
// indices_elem_ty.get_bit_width(),
// 32,
// "Expected list[int32] but got list[int{}]",
// indices_elem_ty.get_bit_width()
// );
// let index = call_ndarray_flatten_index(generator, ctx, *self.0, indices); // let index = call_ndarray_flatten_index(generator, ctx, *self.0, indices);
@ -1797,27 +1801,39 @@ struct StructFieldsBuilder<'ctx> {
} }
impl<'ctx> StructField<'ctx> { impl<'ctx> StructField<'ctx> {
/// TODO: DOCUMENT ME
pub fn gep( pub fn gep(
&self, &self,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
ptr: PointerValue<'ctx>, struct_ptr: PointerValue<'ctx>,
) -> PointerValue<'ctx> { ) -> PointerValue<'ctx> {
ctx.builder.build_struct_gep(ptr, self.gep_index, self.name).unwrap() let index_type = ctx.ctx.i32_type(); // TODO: I think I'm not supposed to use i32 for GEP like that
unsafe {
ctx.builder
.build_in_bounds_gep(
struct_ptr,
&[index_type.const_zero(), index_type.const_int(self.gep_index as u64, false)],
self.name,
)
.unwrap()
}
} }
/// TODO: DOCUMENT ME
pub fn load( pub fn load(
&self, &self,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
ptr: PointerValue<'ctx>, struct_ptr: PointerValue<'ctx>,
) -> BasicValueEnum<'ctx> { ) -> BasicValueEnum<'ctx> {
ctx.builder.build_load(self.gep(ctx, ptr), self.name).unwrap() ctx.builder.build_load(self.gep(ctx, struct_ptr), self.name).unwrap()
} }
pub fn store<V>(&self, ctx: &CodeGenContext<'ctx, '_>, ptr: PointerValue<'ctx>, value: V) /// TODO: DOCUMENT ME
pub fn store<V>(&self, ctx: &CodeGenContext<'ctx, '_>, struct_ptr: PointerValue<'ctx>, value: V)
where where
V: BasicValue<'ctx>, V: BasicValue<'ctx>,
{ {
ctx.builder.build_store(ptr, value).unwrap(); ctx.builder.build_store(self.gep(ctx, struct_ptr), value).unwrap();
} }
} }
@ -1856,7 +1872,7 @@ impl<'ctx> StructFields<'ctx> {
self.fields.len() as u32 self.fields.len() as u32
} }
pub fn as_struct_type(&self, ctx: &'ctx Context) -> StructType<'ctx> { pub fn get_struct_type(&self, ctx: &'ctx Context) -> StructType<'ctx> {
let llvm_fields = self.fields.iter().map(|field| field.ty).collect_vec(); let llvm_fields = self.fields.iter().map(|field| field.ty).collect_vec();
ctx.struct_type(llvm_fields.as_slice(), false) ctx.struct_type(llvm_fields.as_slice(), false)
} }
@ -1900,7 +1916,11 @@ impl<'ctx> StructFieldsBuilder<'ctx> {
fn add_field(&mut self, name: &'static str, ty: BasicTypeEnum<'ctx>) -> StructField<'ctx> { fn add_field(&mut self, name: &'static str, ty: BasicTypeEnum<'ctx>) -> StructField<'ctx> {
let index = self.gep_index_counter; let index = self.gep_index_counter;
self.gep_index_counter += 1; self.gep_index_counter += 1;
StructField { gep_index: index, name, ty }
let field = StructField { gep_index: index, name, ty };
self.fields.push(field); // Register into self.fields
field // Return to the caller to conveniently let them do whatever they want
} }
fn end(self) -> StructFields<'ctx> { fn end(self) -> StructFields<'ctx> {
@ -1931,8 +1951,8 @@ impl<'ctx> NpArrayType<'ctx> {
NpArrayType { size_type, elem_type: ctx.ctx.i8_type().as_basic_type_enum() } NpArrayType { size_type, elem_type: ctx.ctx.i8_type().as_basic_type_enum() }
} }
pub fn struct_type(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructType<'ctx> { pub fn get_struct_type(&self, ctx: &'ctx Context) -> StructType<'ctx> {
self.fields().whole_struct.as_struct_type(ctx.ctx) self.fields().whole_struct.get_struct_type(ctx)
} }
pub fn fields(&self) -> NpArrayStructFields<'ctx> { pub fn fields(&self) -> NpArrayStructFields<'ctx> {
@ -1958,29 +1978,43 @@ impl<'ctx> NpArrayType<'ctx> {
/// - `ndarray.itemsize` will be initialized to the size of `self.elem_type.size_of()`. /// - `ndarray.itemsize` will be initialized to the size of `self.elem_type.size_of()`.
/// - `ndarray.shape` and `ndarray.strides` will be allocated on the stack with number of elements being `in_ndims`, /// - `ndarray.shape` and `ndarray.strides` will be allocated on the stack with number of elements being `in_ndims`,
/// all with empty/uninitialized values. /// all with empty/uninitialized values.
pub fn alloca( pub fn var_alloc<G>(
&self, &self,
ctx: &CodeGenContext<'ctx, '_>, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
in_ndims: IntValue<'ctx>, in_ndims: IntValue<'ctx>,
name: &str, name: Option<&str>,
) -> NpArrayValue<'ctx> { ) -> NpArrayValue<'ctx>
let fields = self.fields(); where
let ptr = G: CodeGenerator + ?Sized,
ctx.builder.build_alloca(fields.whole_struct.as_struct_type(ctx.ctx), name).unwrap(); {
let ptr = generator
.gen_var_alloc(ctx, self.get_struct_type(ctx.ctx).as_basic_type_enum(), name)
.unwrap();
// Allocate `in_dims` number of `size_type` on the stack for `shape` and `strides` // Allocate `in_dims` number of `size_type` on the stack for `shape` and `strides`
let allocated_shape = let allocated_shape = generator
ctx.builder.build_array_alloca(fields.shape.ty, in_ndims, "allocated_shape").unwrap(); .gen_array_var_alloc(
let allocated_strides = ctx ctx,
.builder self.size_type.as_basic_type_enum(),
.build_array_alloca(fields.strides.ty, in_ndims, "allocated_strides") in_ndims,
Some("allocated_shape"),
)
.unwrap();
let allocated_strides = generator
.gen_array_var_alloc(
ctx,
self.size_type.as_basic_type_enum(),
in_ndims,
Some("allocated_strides"),
)
.unwrap(); .unwrap();
let value = NpArrayValue { ty: *self, ptr }; let value = NpArrayValue { ty: *self, ptr };
value.store_ndims(ctx, in_ndims); value.store_ndims(ctx, in_ndims);
value.store_itemsize(ctx, self.elem_type.size_of().unwrap()); value.store_itemsize(ctx, self.elem_type.size_of().unwrap());
value.store_shape(ctx, allocated_shape); value.store_shape(ctx, allocated_shape.base_ptr(ctx, generator));
value.store_strides(ctx, allocated_strides); value.store_strides(ctx, allocated_strides.base_ptr(ctx, generator));
return value; return value;
} }
@ -2038,13 +2072,15 @@ impl<'ctx> NpArrayValue<'ctx> {
&self, &self,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { ) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
// Get the pointer to `shape`
let field = self.ty.fields().shape; let field = self.ty.fields().shape;
field.gep(ctx, self.ptr); let shape = field.load(ctx, self.ptr).into_pointer_value();
// Load `ndims`
let ndims = self.load_ndims(ctx); let ndims = self.load_ndims(ctx);
TypedArrayLikeAdapter { TypedArrayLikeAdapter {
adapted: ArraySliceValue(self.ptr, ndims, Some(field.name)), adapted: ArraySliceValue(shape, ndims, Some(field.name)),
downcast_fn: Box::new(|_ctx, x| x.into_int_value()), downcast_fn: Box::new(|_ctx, x| x.into_int_value()),
upcast_fn: Box::new(|_ctx, x| x.as_basic_value_enum()), upcast_fn: Box::new(|_ctx, x| x.as_basic_value_enum()),
} }
@ -2055,13 +2091,15 @@ impl<'ctx> NpArrayValue<'ctx> {
&self, &self,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { ) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
// Get the pointer to `strides`
let field = self.ty.fields().strides; let field = self.ty.fields().strides;
field.gep(ctx, self.ptr); let strides = field.load(ctx, self.ptr).into_pointer_value();
// Load `ndims`
let ndims = self.load_ndims(ctx); let ndims = self.load_ndims(ctx);
TypedArrayLikeAdapter { TypedArrayLikeAdapter {
adapted: ArraySliceValue(self.ptr, ndims, Some(field.name)), adapted: ArraySliceValue(strides, ndims, Some(field.name)),
downcast_fn: Box::new(|_ctx, x| x.into_int_value()), downcast_fn: Box::new(|_ctx, x| x.into_int_value()),
upcast_fn: Box::new(|_ctx, x| x.as_basic_value_enum()), upcast_fn: Box::new(|_ctx, x| x.as_basic_value_enum()),
} }

View File

@ -4,8 +4,8 @@ mod test;
use super::{ use super::{
classes::{ classes::{
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue, NpArrayType, check_basic_types_match, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue,
NpArrayValue, TypedArrayLikeAdapter, UntypedArrayLikeAccessor, NDArrayValue, NpArrayType, NpArrayValue, TypedArrayLikeAdapter, UntypedArrayLikeAccessor,
}, },
llvm_intrinsics, CodeGenContext, CodeGenerator, llvm_intrinsics, CodeGenContext, CodeGenerator,
}; };
@ -17,7 +17,7 @@ use inkwell::{
memory_buffer::MemoryBuffer, memory_buffer::MemoryBuffer,
module::Module, module::Module,
types::{BasicType, BasicTypeEnum, FunctionType, IntType, PointerType}, types::{BasicType, BasicTypeEnum, FunctionType, IntType, PointerType},
values::{BasicValueEnum, CallSiteValue, FloatValue, FunctionValue, IntValue}, values::{BasicValue, BasicValueEnum, CallSiteValue, FloatValue, FunctionValue, IntValue},
AddressSpace, IntPredicate, AddressSpace, IntPredicate,
}; };
use itertools::Either; use itertools::Either;
@ -565,370 +565,370 @@ pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> Flo
.unwrap() .unwrap()
} }
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the // /// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the
/// calculated total size. // /// calculated total size.
/// // ///
/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension. // /// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension.
/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for, // /// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for,
/// or [`None`] if starting from the first dimension and ending at the last dimension respectively. // /// or [`None`] if starting from the first dimension and ending at the last dimension respectively.
pub fn call_ndarray_calc_size<'ctx, G, Dims>( // pub fn call_ndarray_calc_size<'ctx, G, Dims>(
generator: &G, // generator: &G,
ctx: &CodeGenContext<'ctx, '_>, // ctx: &CodeGenContext<'ctx, '_>,
dims: &Dims, // dims: &Dims,
(begin, end): (Option<IntValue<'ctx>>, Option<IntValue<'ctx>>), // (begin, end): (Option<IntValue<'ctx>>, Option<IntValue<'ctx>>),
) -> IntValue<'ctx> // ) -> IntValue<'ctx>
where // where
G: CodeGenerator + ?Sized, // G: CodeGenerator + ?Sized,
Dims: ArrayLikeIndexer<'ctx>, // Dims: ArrayLikeIndexer<'ctx>,
{ // {
let llvm_usize = generator.get_size_type(ctx.ctx); // let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); // let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
//
let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() { // let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_size", // 32 => "__nac3_ndarray_calc_size",
64 => "__nac3_ndarray_calc_size64", // 64 => "__nac3_ndarray_calc_size64",
bw => unreachable!("Unsupported size type bit width: {}", bw), // bw => unreachable!("Unsupported size type bit width: {}", bw),
}; // };
let ndarray_calc_size_fn_t = llvm_usize.fn_type( // let ndarray_calc_size_fn_t = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()], // &[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()],
false, // false,
); // );
let ndarray_calc_size_fn = // let ndarray_calc_size_fn =
ctx.module.get_function(ndarray_calc_size_fn_name).unwrap_or_else(|| { // ctx.module.get_function(ndarray_calc_size_fn_name).unwrap_or_else(|| {
ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None) // ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None)
}); // });
//
let begin = begin.unwrap_or_else(|| llvm_usize.const_zero()); // let begin = begin.unwrap_or_else(|| llvm_usize.const_zero());
let end = end.unwrap_or_else(|| dims.size(ctx, generator)); // let end = end.unwrap_or_else(|| dims.size(ctx, generator));
ctx.builder // ctx.builder
.build_call( // .build_call(
ndarray_calc_size_fn, // ndarray_calc_size_fn,
&[ // &[
dims.base_ptr(ctx, generator).into(), // dims.base_ptr(ctx, generator).into(),
dims.size(ctx, generator).into(), // dims.size(ctx, generator).into(),
begin.into(), // begin.into(),
end.into(), // end.into(),
], // ],
"", // "",
) // )
.map(CallSiteValue::try_as_basic_value) // .map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value)) // .map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left) // .map(Either::unwrap_left)
.unwrap() // .unwrap()
} // }
//
/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdpater`] // /// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdpater`]
/// containing `i32` indices of the flattened index. // /// containing `i32` indices of the flattened index.
/// // ///
/// * `index` - The index to compute the multidimensional index for. // /// * `index` - The index to compute the multidimensional index for.
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an // /// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
/// `NDArray`. // /// `NDArray`.
pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>( // pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
generator: &G, // generator: &G,
ctx: &mut CodeGenContext<'ctx, '_>, // ctx: &mut CodeGenContext<'ctx, '_>,
index: IntValue<'ctx>, // index: IntValue<'ctx>,
ndarray: NDArrayValue<'ctx>, // ndarray: NDArrayValue<'ctx>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { // ) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_void = ctx.ctx.void_type(); // let llvm_void = ctx.ctx.void_type();
let llvm_i32 = ctx.ctx.i32_type(); // let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx); // let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); // let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); // let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
//
let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() { // let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_nd_indices", // 32 => "__nac3_ndarray_calc_nd_indices",
64 => "__nac3_ndarray_calc_nd_indices64", // 64 => "__nac3_ndarray_calc_nd_indices64",
bw => unreachable!("Unsupported size type bit width: {}", bw), // bw => unreachable!("Unsupported size type bit width: {}", bw),
}; // };
let ndarray_calc_nd_indices_fn = // let ndarray_calc_nd_indices_fn =
ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| { // ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| {
let fn_type = llvm_void.fn_type( // let fn_type = llvm_void.fn_type(
&[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()], // &[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()],
false, // false,
); // );
//
ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None) // ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None)
}); // });
//
let ndarray_num_dims = ndarray.load_ndims(ctx); // let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.dim_sizes(); // let ndarray_dims = ndarray.dim_sizes();
//
let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap(); // let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap();
//
ctx.builder // ctx.builder
.build_call( // .build_call(
ndarray_calc_nd_indices_fn, // ndarray_calc_nd_indices_fn,
&[ // &[
index.into(), // index.into(),
ndarray_dims.base_ptr(ctx, generator).into(), // ndarray_dims.base_ptr(ctx, generator).into(),
ndarray_num_dims.into(), // ndarray_num_dims.into(),
indices.into(), // indices.into(),
], // ],
"", // "",
) // )
.unwrap(); // .unwrap();
//
TypedArrayLikeAdapter::from( // TypedArrayLikeAdapter::from(
ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None), // ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None),
Box::new(|_, v| v.into_int_value()), // Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()), // Box::new(|_, v| v.into()),
) // )
} // }
//
fn call_ndarray_flatten_index_impl<'ctx, G, Indices>( // fn call_ndarray_flatten_index_impl<'ctx, G, Indices>(
generator: &G, // generator: &G,
ctx: &CodeGenContext<'ctx, '_>, // ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>, // ndarray: NDArrayValue<'ctx>,
indices: &Indices, // indices: &Indices,
) -> IntValue<'ctx> // ) -> IntValue<'ctx>
where // where
G: CodeGenerator + ?Sized, // G: CodeGenerator + ?Sized,
Indices: ArrayLikeIndexer<'ctx>, // Indices: ArrayLikeIndexer<'ctx>,
{ // {
let llvm_i32 = ctx.ctx.i32_type(); // let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx); // let llvm_usize = generator.get_size_type(ctx.ctx);
//
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); // let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); // let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
//
debug_assert_eq!( // debug_assert_eq!(
IntType::try_from(indices.element_type(ctx, generator)) // IntType::try_from(indices.element_type(ctx, generator))
.map(IntType::get_bit_width) // .map(IntType::get_bit_width)
.unwrap_or_default(), // .unwrap_or_default(),
llvm_i32.get_bit_width(), // llvm_i32.get_bit_width(),
"Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`" // "Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`"
); // );
debug_assert_eq!( // debug_assert_eq!(
indices.size(ctx, generator).get_type().get_bit_width(), // indices.size(ctx, generator).get_type().get_bit_width(),
llvm_usize.get_bit_width(), // llvm_usize.get_bit_width(),
"Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`" // "Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`"
); // );
//
let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() { // let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_flatten_index", // 32 => "__nac3_ndarray_flatten_index",
64 => "__nac3_ndarray_flatten_index64", // 64 => "__nac3_ndarray_flatten_index64",
bw => unreachable!("Unsupported size type bit width: {}", bw), // bw => unreachable!("Unsupported size type bit width: {}", bw),
}; // };
let ndarray_flatten_index_fn = // let ndarray_flatten_index_fn =
ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| { // ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type( // let fn_type = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()], // &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()],
false, // false,
); // );
//
ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None) // ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None)
}); // });
//
let ndarray_num_dims = ndarray.load_ndims(ctx); // let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.dim_sizes(); // let ndarray_dims = ndarray.dim_sizes();
//
let index = ctx // let index = ctx
.builder // .builder
.build_call( // .build_call(
ndarray_flatten_index_fn, // ndarray_flatten_index_fn,
&[ // &[
ndarray_dims.base_ptr(ctx, generator).into(), // ndarray_dims.base_ptr(ctx, generator).into(),
ndarray_num_dims.into(), // ndarray_num_dims.into(),
indices.base_ptr(ctx, generator).into(), // indices.base_ptr(ctx, generator).into(),
indices.size(ctx, generator).into(), // indices.size(ctx, generator).into(),
], // ],
"", // "",
) // )
.map(CallSiteValue::try_as_basic_value) // .map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value)) // .map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left) // .map(Either::unwrap_left)
.unwrap(); // .unwrap();
//
index // index
} // }
//
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the // /// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
/// multidimensional index. // /// multidimensional index.
/// // ///
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an // /// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
/// `NDArray`. // /// `NDArray`.
/// * `indices` - The multidimensional index to compute the flattened index for. // /// * `indices` - The multidimensional index to compute the flattened index for.
pub fn call_ndarray_flatten_index<'ctx, G, Index>( // pub fn call_ndarray_flatten_index<'ctx, G, Index>(
generator: &mut G, // generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, // ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>, // ndarray: NDArrayValue<'ctx>,
indices: &Index, // indices: &Index,
) -> IntValue<'ctx> // ) -> IntValue<'ctx>
where // where
G: CodeGenerator + ?Sized, // G: CodeGenerator + ?Sized,
Index: ArrayLikeIndexer<'ctx>, // Index: ArrayLikeIndexer<'ctx>,
{ // {
call_ndarray_flatten_index_impl(generator, ctx, ndarray, indices) // call_ndarray_flatten_index_impl(generator, ctx, ndarray, indices)
} // }
//
/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of // /// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of
/// dimension and size of each dimension of the resultant `ndarray`. // /// dimension and size of each dimension of the resultant `ndarray`.
pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( // pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G, // generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, // ctx: &mut CodeGenContext<'ctx, '_>,
lhs: NDArrayValue<'ctx>, // lhs: NDArrayValue<'ctx>,
rhs: NDArrayValue<'ctx>, // rhs: NDArrayValue<'ctx>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { // ) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_usize = generator.get_size_type(ctx.ctx); // let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); // let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
//
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { // let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast", // 32 => "__nac3_ndarray_calc_broadcast",
64 => "__nac3_ndarray_calc_broadcast64", // 64 => "__nac3_ndarray_calc_broadcast64",
bw => unreachable!("Unsupported size type bit width: {}", bw), // bw => unreachable!("Unsupported size type bit width: {}", bw),
}; // };
let ndarray_calc_broadcast_fn = // let ndarray_calc_broadcast_fn =
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { // ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type( // let fn_type = llvm_usize.fn_type(
&[ // &[
llvm_pusize.into(), // llvm_pusize.into(),
llvm_usize.into(), // llvm_usize.into(),
llvm_pusize.into(), // llvm_pusize.into(),
llvm_usize.into(), // llvm_usize.into(),
llvm_pusize.into(), // llvm_pusize.into(),
], // ],
false, // false,
); // );
//
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) // ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
}); // });
//
let lhs_ndims = lhs.load_ndims(ctx); // let lhs_ndims = lhs.load_ndims(ctx);
let rhs_ndims = rhs.load_ndims(ctx); // let rhs_ndims = rhs.load_ndims(ctx);
let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None); // let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None);
//
gen_for_callback_incrementing( // gen_for_callback_incrementing(
generator, // generator,
ctx, // ctx,
llvm_usize.const_zero(), // llvm_usize.const_zero(),
(min_ndims, false), // (min_ndims, false),
|generator, ctx, _, idx| { // |generator, ctx, _, idx| {
let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap(); // let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap();
let (lhs_dim_sz, rhs_dim_sz) = unsafe { // let (lhs_dim_sz, rhs_dim_sz) = unsafe {
( // (
lhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None), // lhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None),
rhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None), // rhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None),
) // )
}; // };
//
let llvm_usize_const_one = llvm_usize.const_int(1, false); // let llvm_usize_const_one = llvm_usize.const_int(1, false);
let lhs_eqz = ctx // let lhs_eqz = ctx
.builder // .builder
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, llvm_usize_const_one, "") // .build_int_compare(IntPredicate::EQ, lhs_dim_sz, llvm_usize_const_one, "")
.unwrap(); // .unwrap();
let rhs_eqz = ctx // let rhs_eqz = ctx
.builder // .builder
.build_int_compare(IntPredicate::EQ, rhs_dim_sz, llvm_usize_const_one, "") // .build_int_compare(IntPredicate::EQ, rhs_dim_sz, llvm_usize_const_one, "")
.unwrap(); // .unwrap();
let lhs_or_rhs_eqz = ctx.builder.build_or(lhs_eqz, rhs_eqz, "").unwrap(); // let lhs_or_rhs_eqz = ctx.builder.build_or(lhs_eqz, rhs_eqz, "").unwrap();
//
let lhs_eq_rhs = ctx // let lhs_eq_rhs = ctx
.builder // .builder
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, rhs_dim_sz, "") // .build_int_compare(IntPredicate::EQ, lhs_dim_sz, rhs_dim_sz, "")
.unwrap(); // .unwrap();
//
let is_compatible = ctx.builder.build_or(lhs_or_rhs_eqz, lhs_eq_rhs, "").unwrap(); // let is_compatible = ctx.builder.build_or(lhs_or_rhs_eqz, lhs_eq_rhs, "").unwrap();
//
ctx.make_assert( // ctx.make_assert(
generator, // generator,
is_compatible, // is_compatible,
"0:ValueError", // "0:ValueError",
"operands could not be broadcast together", // "operands could not be broadcast together",
[None, None, None], // [None, None, None],
ctx.current_loc, // ctx.current_loc,
); // );
//
Ok(()) // Ok(())
}, // },
llvm_usize.const_int(1, false), // llvm_usize.const_int(1, false),
) // )
.unwrap(); // .unwrap();
//
let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None); // let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None);
let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator); // let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator);
let lhs_ndims = lhs.load_ndims(ctx); // let lhs_ndims = lhs.load_ndims(ctx);
let rhs_dims = rhs.dim_sizes().base_ptr(ctx, generator); // let rhs_dims = rhs.dim_sizes().base_ptr(ctx, generator);
let rhs_ndims = rhs.load_ndims(ctx); // let rhs_ndims = rhs.load_ndims(ctx);
let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap(); // let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap();
let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None); // let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None);
//
ctx.builder // ctx.builder
.build_call( // .build_call(
ndarray_calc_broadcast_fn, // ndarray_calc_broadcast_fn,
&[ // &[
lhs_dims.into(), // lhs_dims.into(),
lhs_ndims.into(), // lhs_ndims.into(),
rhs_dims.into(), // rhs_dims.into(),
rhs_ndims.into(), // rhs_ndims.into(),
out_dims.base_ptr(ctx, generator).into(), // out_dims.base_ptr(ctx, generator).into(),
], // ],
"", // "",
) // )
.unwrap(); // .unwrap();
//
TypedArrayLikeAdapter::from( // TypedArrayLikeAdapter::from(
out_dims, // out_dims,
Box::new(|_, v| v.into_int_value()), // Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()), // Box::new(|_, v| v.into()),
) // )
} // }
//
/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`] // /// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`]
/// containing the indices used for accessing `array` corresponding to the index of the broadcasted // /// containing the indices used for accessing `array` corresponding to the index of the broadcasted
/// array `broadcast_idx`. // /// array `broadcast_idx`.
pub fn call_ndarray_calc_broadcast_index< // pub fn call_ndarray_calc_broadcast_index<
'ctx, // 'ctx,
G: CodeGenerator + ?Sized, // G: CodeGenerator + ?Sized,
BroadcastIdx: UntypedArrayLikeAccessor<'ctx>, // BroadcastIdx: UntypedArrayLikeAccessor<'ctx>,
>( // >(
generator: &mut G, // generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, // ctx: &mut CodeGenContext<'ctx, '_>,
array: NDArrayValue<'ctx>, // array: NDArrayValue<'ctx>,
broadcast_idx: &BroadcastIdx, // broadcast_idx: &BroadcastIdx,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { // ) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_i32 = ctx.ctx.i32_type(); // let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx); // let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); // let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); // let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
//
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { // let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast_idx", // 32 => "__nac3_ndarray_calc_broadcast_idx",
64 => "__nac3_ndarray_calc_broadcast_idx64", // 64 => "__nac3_ndarray_calc_broadcast_idx64",
bw => unreachable!("Unsupported size type bit width: {}", bw), // bw => unreachable!("Unsupported size type bit width: {}", bw),
}; // };
let ndarray_calc_broadcast_fn = // let ndarray_calc_broadcast_fn =
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { // ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type( // let fn_type = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()], // &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()],
false, // false,
); // );
//
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) // ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
}); // });
//
let broadcast_size = broadcast_idx.size(ctx, generator); // let broadcast_size = broadcast_idx.size(ctx, generator);
let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap(); // let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap();
//
let array_dims = array.dim_sizes().base_ptr(ctx, generator); // let array_dims = array.dim_sizes().base_ptr(ctx, generator);
let array_ndims = array.load_ndims(ctx); // let array_ndims = array.load_ndims(ctx);
let broadcast_idx_ptr = unsafe { // let broadcast_idx_ptr = unsafe {
broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None) // broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
}; // };
//
ctx.builder // ctx.builder
.build_call( // .build_call(
ndarray_calc_broadcast_fn, // ndarray_calc_broadcast_fn,
&[array_dims.into(), array_ndims.into(), broadcast_idx_ptr.into(), out_idx.into()], // &[array_dims.into(), array_ndims.into(), broadcast_idx_ptr.into(), out_idx.into()],
"", // "",
) // )
.unwrap(); // .unwrap();
//
TypedArrayLikeAdapter::from( // TypedArrayLikeAdapter::from(
ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None), // ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None),
Box::new(|_, v| v.into_int_value()), // Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()), // Box::new(|_, v| v.into()),
) // )
} // }
fn get_size_variant<'ctx>(ty: IntType<'ctx>) -> SizeVariant { fn get_size_variant<'ctx>(ty: IntType<'ctx>) -> SizeVariant {
match ty.get_bit_width() { match ty.get_bit_width() {
@ -965,21 +965,28 @@ where
}) })
} }
fn get_ndarray_struct_ptr<'ctx>(ctx: &'ctx Context, size_type: IntType<'ctx>) -> PointerType<'ctx> { fn get_irrt_ndarray_ptr_type<'ctx>(
let i8_type = ctx.i8_type(); ctx: &CodeGenContext<'ctx, '_>,
size_type: IntType<'ctx>,
) -> PointerType<'ctx> {
let i8_type = ctx.ctx.i8_type();
let ndarray_ty = NpArrayType { size_type, elem_type: i8_type.as_basic_type_enum() }; let ndarray_ty = NpArrayType { size_type, elem_type: i8_type.as_basic_type_enum() };
let struct_ty = ndarray_ty.fields().whole_struct.as_struct_type(ctx); let struct_ty = ndarray_ty.get_struct_type(ctx.ctx);
struct_ty.ptr_type(AddressSpace::default()) struct_ty.ptr_type(AddressSpace::default())
} }
fn get_irrt_opaque_uint8_ptr_type<'ctx>(ctx: &CodeGenContext<'ctx, '_>) -> PointerType<'ctx> {
ctx.ctx.i8_type().ptr_type(AddressSpace::default())
}
pub fn call_nac3_ndarray_size<'ctx>( pub fn call_nac3_ndarray_size<'ctx>(
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
ndarray: NpArrayValue<'ctx>, ndarray: NpArrayValue<'ctx>,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
let size_type = ndarray.ty.size_type; let size_type = ndarray.ty.size_type;
let function = get_size_type_dependent_function(ctx, size_type, "__nac3_ndarray_size", || { let function = get_size_type_dependent_function(ctx, size_type, "__nac3_ndarray_size", || {
size_type.fn_type(&[get_ndarray_struct_ptr(ctx.ctx, size_type).into()], false) size_type.fn_type(&[get_irrt_ndarray_ptr_type(ctx, size_type).into()], false)
}); });
ctx.builder ctx.builder
@ -989,3 +996,44 @@ pub fn call_nac3_ndarray_size<'ctx>(
.unwrap_left() .unwrap_left()
.into_int_value() .into_int_value()
} }
pub fn call_nac3_ndarray_fill_generic<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
ndarray: NpArrayValue<'ctx>,
fill_value: BasicValueEnum<'ctx>,
) {
// Sanity check on type of `fill_value`
check_basic_types_match(ndarray.ty.elem_type, fill_value.get_type().as_basic_type_enum())
.unwrap();
let size_type = ndarray.ty.size_type;
let function =
get_size_type_dependent_function(ctx, size_type, "__nac3_ndarray_fill_generic", || {
ctx.ctx.void_type().fn_type(
&[
get_irrt_ndarray_ptr_type(ctx, size_type).into(), // NDArray<SizeT>* ndarray
get_irrt_opaque_uint8_ptr_type(ctx).into(), // uint8_t* pvalue
],
false,
)
});
// Put `fill_value` onto the stack and get a pointer to it, and that pointer will be `pvalue`
let pvalue = ctx.builder.build_alloca(ndarray.ty.elem_type, "fill_value").unwrap();
ctx.builder.build_store(pvalue, fill_value).unwrap();
// Cast pvalue to `uint8_t*`
let pvalue = ctx.builder.build_pointer_cast(pvalue, get_irrt_opaque_uint8_ptr_type(ctx), "").unwrap();
// Call the IRRT function
ctx.builder
.build_call(
function,
&[
ndarray.ptr.into(), // ndarray
pvalue.into(), // pvalue
],
"",
)
.unwrap();
}

View File

@ -7,6 +7,7 @@ use crate::{
typedef::{CallId, FuncArg, Type, TypeEnum, Unifier}, typedef::{CallId, FuncArg, Type, TypeEnum, Unifier},
}, },
}; };
use classes::NpArrayType;
use crossbeam::channel::{unbounded, Receiver, Sender}; use crossbeam::channel::{unbounded, Receiver, Sender};
use inkwell::{ use inkwell::{
attributes::{Attribute, AttributeLoc}, attributes::{Attribute, AttributeLoc},
@ -476,7 +477,11 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
ctx, module, generator, unifier, top_level, type_cache, dtype, ctx, module, generator, unifier, top_level, type_cache, dtype,
); );
NDArrayType::new(generator, ctx, element_type).as_base_type().into() let ndarray_ty = NpArrayType {
size_type: generator.get_size_type(ctx),
elem_type: element_type,
};
ndarray_ty.get_struct_type(ctx).as_basic_type_enum()
} }
_ => unreachable!( _ => unreachable!(

View File

@ -2,15 +2,11 @@ use crate::{
codegen::{ codegen::{
classes::{ classes::{
ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, NDArrayType, NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, NDArrayType, NDArrayValue,
ProxyType, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, NpArrayType, ProxyType, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter,
TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
}, },
expr::gen_binop_expr_with_values, expr::gen_binop_expr_with_values,
irrt::{ irrt::call_nac3_ndarray_fill_generic,
calculate_len_for_slice_range, call_ndarray_calc_broadcast,
call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices,
call_ndarray_calc_size,
},
llvm_intrinsics::{self, call_memcpy_generic}, llvm_intrinsics::{self, call_memcpy_generic},
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
@ -26,7 +22,7 @@ use crate::{
typedef::{FunSignature, Type, TypeEnum}, typedef::{FunSignature, Type, TypeEnum},
}, },
}; };
use inkwell::types::{AnyTypeEnum, BasicTypeEnum, PointerType}; use inkwell::types::{AnyTypeEnum, BasicTypeEnum, IntType, PointerType};
use inkwell::{ use inkwell::{
types::BasicType, types::BasicType,
values::{BasicValueEnum, IntValue, PointerValue}, values::{BasicValueEnum, IntValue, PointerValue},
@ -34,6 +30,8 @@ use inkwell::{
}; };
use nac3parser::ast::{Operator, StrRef}; use nac3parser::ast::{Operator, StrRef};
use super::{classes::NpArrayValue, stmt::gen_return};
// /// Creates an uninitialized `NDArray` instance. // /// Creates an uninitialized `NDArray` instance.
// fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>( // fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
// generator: &mut G, // generator: &mut G,
@ -2015,3 +2013,335 @@ use nac3parser::ast::{Operator, StrRef};
// Ok(()) // Ok(())
// } // }
// //
fn simple_assert<'ctx, G>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
cond: IntValue<'ctx>,
msg: &str,
) where
G: CodeGenerator + ?Sized,
{
let mut full_msg = String::from("simple_assert failed: ");
full_msg.push_str(msg);
ctx.make_assert(
generator,
cond,
"0:ValueError",
full_msg.as_str(),
[None, None, None],
ctx.current_loc,
);
}
fn copy_array_slice<'ctx, G, Src, Dst>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dst: Dst,
src: Src,
) where
G: CodeGenerator + ?Sized,
Dst: TypedArrayLikeMutator<'ctx, IntType<'ctx>>,
Src: TypedArrayLikeAccessor<'ctx, IntType<'ctx>>,
{
// Sanity check
let len_match = ctx
.builder
.build_int_compare(
IntPredicate::EQ,
src.size(ctx, generator),
dst.size(ctx, generator),
"len_match",
)
.unwrap();
simple_assert(generator, ctx, len_match, "copy_array_slice length mismatched");
let size_type = generator.get_size_type(ctx.ctx);
let init_val = size_type.const_zero();
let max_val = (dst.size(ctx, generator), false);
let incr_val = size_type.const_int(1, false);
gen_for_callback_incrementing(
generator,
ctx,
init_val,
max_val,
|generator, ctx, _hooks, idx| {
let value = src.get_typed(ctx, generator, &idx, Some("copy_array_slice.tmp"));
dst.set_typed(ctx, generator, &idx, value);
Ok(())
},
incr_val,
)
.unwrap();
}
pub fn alloca_ndarray_uninitialized<'ctx, G>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_type: BasicTypeEnum<'ctx>,
ndims: IntValue<'ctx>,
name: Option<&str>,
) -> Result<NpArrayValue<'ctx>, String>
where
G: CodeGenerator + ?Sized,
{
let size_type = generator.get_size_type(ctx.ctx);
let ndarray_ty = NpArrayType { size_type, elem_type };
let ndarray = ndarray_ty.var_alloc(generator, ctx, ndims, name);
Ok(ndarray)
}
pub struct Producer<'ctx, G: CodeGenerator + ?Sized, T> {
pub count: IntValue<'ctx>,
pub write_to_slice: Box<
dyn Fn(
&mut G,
&mut CodeGenContext<'ctx, '_>,
&TypedArrayLikeAdapter<'ctx, T>,
) -> Result<(), String>
+ 'ctx,
>,
}
/// TODO: UPDATE DOCUMENTATION
/// LLVM-typed implementation for generating a [`Producer`] that sets a list of ints.
///
/// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The `shape` parameter used to construct the `NDArray`.
///
/// ### Notes on `shape`
///
/// Just like numpy, the `shape` argument can be:
/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))`
/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
///
/// See also [`typecheck::type_inferencer::fold_numpy_function_call_shape_argument`] to
/// learn how `shape` gets from being a Python user expression to here.
fn parse_input_shape_arg<'ctx, G>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
shape: BasicValueEnum<'ctx>,
shape_ty: Type,
) -> Result<Producer<'ctx, G, IntValue<'ctx>>, String>
where
G: CodeGenerator + ?Sized,
{
let size_type = generator.get_size_type(ctx.ctx);
match &*ctx.unifier.get_ty(shape_ty) {
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
{
// 1. A list of ints; e.g., `np.empty([600, 800, 3])`
// A list has to be a PointerValue
let shape_list = ListValue::from_ptr_val(shape.into_pointer_value(), size_type, None);
// Create `Producer`
let ndims = shape_list.load_size(ctx, Some("count"));
Ok(Producer {
count: ndims,
write_to_slice: Box::new(move |ctx, generator, dst_slice| {
// Basically iterate through the list and write to `dst_slice` accordingly
let init_val = size_type.const_zero();
let max_val = (ndims, false);
let incr_val = size_type.const_int(1, false);
gen_for_callback_incrementing(
ctx,
generator,
init_val,
max_val,
|generator, ctx, _hooks, idx| {
// Get the dimension at `idx`
let dim =
shape_list.data().get(ctx, generator, &idx, None).into_int_value();
// Cast `dim` to SizeT
let dim = ctx
.builder
.build_int_s_extend_or_bit_cast(dim, size_type, "dim_casted")
.unwrap();
// Write
dst_slice.set_typed(ctx, generator, &idx, dim);
Ok(())
},
incr_val,
)?;
Ok(())
}),
})
}
TypeEnum::TTuple { ty: tuple_types } => {
// 2. A tuple of ints; e.g., `np.empty((600, 800, 3))`
// Get the length/size of the tuple, which also happens to be the value of `ndims`.
let ndims = tuple_types.len();
// A tuple has to be a StructValue
// Read [`codegen::expr::gen_expr`] to see how `nac3core` translates a Python tuple into LLVM.
let shape_tuple = shape.into_struct_value();
Ok(Producer {
count: size_type.const_int(ndims as u64, false),
write_to_slice: Box::new(move |generator, ctx, dst_slice| {
for dim_i in 0..ndims {
// Get the dimension at `dim_i`
let dim = ctx
.builder
.build_extract_value(
shape_tuple,
dim_i as u32,
format!("dim{dim_i}").as_str(),
)
.unwrap()
.into_int_value();
// Cast `dim` to SizeT
let dim = ctx
.builder
.build_int_s_extend_or_bit_cast(dim, size_type, "dim_casted")
.unwrap();
// Write
dst_slice.set_typed(
ctx,
generator,
&size_type.const_int(dim_i as u64, false),
dim,
);
}
Ok(())
}),
})
}
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() =>
{
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
// The value has to be an integer
let shape_int = shape.into_int_value();
Ok(Producer {
count: size_type.const_int(1, false),
write_to_slice: Box::new(move |generator, ctx, dst_slice| {
// Only index 0 is set with the input value
let dim_i = size_type.const_zero();
// Cast `shape_int` to SizeT
let dim = ctx
.builder
.build_int_s_extend_or_bit_cast(shape_int, size_type, "dim_casted")
.unwrap();
// Write
dst_slice.set_typed(ctx, generator, &dim_i, dim);
Ok(())
}),
})
}
_ => panic!("parse_input_shape_arg encountered unknown type"),
}
}
/// TODO: DOCUMENT ME
fn alloca_ndarray_uninitialized_shaped<'ctx, G>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_type: BasicTypeEnum<'ctx>,
shape_producer: Producer<'ctx, G, IntValue<'ctx>>,
name: Option<&str>,
) -> Result<NpArrayValue<'ctx>, String>
where
G: CodeGenerator + ?Sized,
{
// Allocate an uninitialized ndarray
let ndims = shape_producer.count;
let ndarray = alloca_ndarray_uninitialized(generator, ctx, elem_type, ndims, name)?;
// Fill `ndarray.shape` with `shape_producer`
(shape_producer.write_to_slice)(generator, ctx, &ndarray.shape_slice(ctx))?;
Ok(ndarray)
}
/// LLVM-typed implementation for generating the implementation for constructing an empty `NDArray`.
fn call_ndarray_empty_impl<'ctx, G>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
shape: BasicValueEnum<'ctx>,
shape_ty: Type,
name: Option<&str>,
) -> Result<NpArrayValue<'ctx>, String>
where
G: CodeGenerator + ?Sized,
{
let elem_type = ctx.get_llvm_type(generator, elem_ty);
let shape_producer = parse_input_shape_arg(generator, ctx, shape, shape_ty)?;
alloca_ndarray_uninitialized_shaped(generator, ctx, elem_type, shape_producer, name)
}
/// Generates LLVM IR for `np.empty`.
pub fn gen_ndarray_empty<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let shape_ty = fun.0.args[0].ty;
let shape = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
let ndarray = call_ndarray_empty_impl(
generator,
context,
context.primitives.float,
shape,
shape_ty,
None,
)?;
Ok(ndarray.ptr)
}
/// Generates LLVM IR for `np.zeros`.
///
/// NOTE: Current `dtype` is always `float64`.
pub fn gen_ndarray_zeros<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let shape_ty = fun.0.args[0].ty;
let shape = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
// Allocate an ndarray and fill it later
let ndarray = call_ndarray_empty_impl(
generator,
context,
context.primitives.float, // float64
shape,
shape_ty,
None,
)?;
// TRICK: The float64 type could be conveniently extracted out of `ndarray`
let float_type = ndarray.ty.elem_type.into_float_type();
// Fill the ndarray
call_nac3_ndarray_fill_generic(context, ndarray, float_type.const_float(1.0).into());
// Return our ndarray
println!("ndarray.ptr = {}", ndarray.ptr);
Ok(ndarray.ptr)
}

View File

@ -1193,14 +1193,13 @@ impl<'a> BuiltinBuilder<'a> {
self.ndarray_float, self.ndarray_float,
&[(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")], &[(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
Box::new(move |ctx, obj, fun, args, generator| { Box::new(move |ctx, obj, fun, args, generator| {
todo!() let func = match prim {
// let func = match prim { PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => gen_ndarray_empty,
// PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => gen_ndarray_empty, PrimDef::FunNpZeros => gen_ndarray_zeros,
// PrimDef::FunNpZeros => gen_ndarray_zeros, PrimDef::FunNpOnes => todo!(), // gen_ndarray_ones,
// PrimDef::FunNpOnes => gen_ndarray_ones, _ => unreachable!(),
// _ => unreachable!(), };
// }; func(ctx, &obj, fun, &args, generator).map(|val| Some(val.as_basic_value_enum()))
// func(ctx, &obj, fun, &args, generator).map(|val| Some(val.as_basic_value_enum()))
}), }),
) )
} }

5
nac3core/src/util.rs Normal file
View File

@ -0,0 +1,5 @@
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SizeVariant {
Bits32,
Bits64,
}

View File

@ -0,0 +1,3 @@
def run() -> int32:
hello = np_zeros((3, 4))
return 0

View File

@ -449,6 +449,9 @@ fn main() {
.create_target_machine(llvm_options.opt_level) .create_target_machine(llvm_options.opt_level)
.expect("couldn't create target machine"); .expect("couldn't create target machine");
// NOTE: DEBUG PRINT
main.print_to_file("standalone.ll").unwrap();
let pass_options = PassBuilderOptions::create(); let pass_options = PassBuilderOptions::create();
pass_options.set_merge_functions(true); pass_options.set_merge_functions(true);
let passes = format!("default<O{}>", opt_level as u32); let passes = format!("default<O{}>", opt_level as u32);