1
0
forked from M-Labs/nac3

Compare commits

..

9 Commits

Author SHA1 Message Date
87d2a4ed59 WIP 2024-07-10 17:27:10 +08:00
9aae290727 core: irrt general numpy broadcasting 2024-07-10 17:05:01 +08:00
d18c769cdc core: irrt general numpy slicing 2024-07-10 14:05:08 +08:00
f41f06aec7 core: more irrt 2024-07-10 11:56:31 +08:00
1303265785 core: build.rs rewrite regex to capture = type 2024-07-10 10:17:45 +08:00
e9cf6ce1e5 core: move irrt c++ sources to /nac3core/irrt 2024-07-10 10:17:45 +08:00
bc91ab9b13 core: IRRT -Werror=return-type 2024-07-10 10:17:43 +08:00
1e06a3d199 core: add irrt_test 2024-07-10 10:11:07 +08:00
87511ac749 core: comment out numpy 2024-07-10 10:05:07 +08:00
12 changed files with 566 additions and 989 deletions

View File

@ -212,11 +212,11 @@ namespace {
return this->size() * itemsize;
}
void set_pelement_value(uint8_t* pelement, const uint8_t* pvalue) {
void set_value_at_pelement(uint8_t* pelement, const uint8_t* pvalue) {
__builtin_memcpy(pelement, pvalue, itemsize);
}
uint8_t* get_pelement_by_indices(const SizeT *indices) {
uint8_t* get_pelement(const SizeT *indices) {
uint8_t* element = data;
for (SizeT dim_i = 0; dim_i < ndims; dim_i++)
element += indices[dim_i] * strides[dim_i];
@ -229,7 +229,7 @@ namespace {
SizeT* indices = (SizeT*) __builtin_alloca(sizeof(SizeT) * this->ndims);
ndarray_util::set_indices_by_nth(this->ndims, this->shape, indices, nth);
return get_pelement_by_indices(indices);
return get_pelement(indices);
}
// Get pointer to the first element of this ndarray, assuming
@ -259,8 +259,8 @@ namespace {
iter.set_indices_zero();
for (SizeT i = 0; i < this->size(); i++, iter.next()) {
uint8_t* pelement = get_pelement_by_indices(iter.indices);
set_pelement_value(pelement, pvalue);
uint8_t* pelement = get_pelement(iter.indices);
set_value_at_pelement(pelement, pvalue);
}
}
@ -283,8 +283,8 @@ namespace {
if (!in_bounds(indices)) continue;
uint8_t* pelement = get_pelement_by_indices(indices);
set_pelement_value(pelement, one_pvalue);
uint8_t* pelement = get_pelement(indices);
set_value_at_pelement(pelement, one_pvalue);
}
}
@ -435,9 +435,9 @@ namespace {
};
const SizeT this_size = this->size();
for (SizeT i = 0; i < this_size; i++, iter.next()) {
uint8_t* src_pelement = broadcasted_src_ndarray_strides->get_pelement_by_indices(indices);
uint8_t* this_pelement = this->get_pelement_by_indices(indices);
this->set_pelement_value(src_pelement, src_pelement);
uint8_t* src_pelement = broadcasted_src_ndarray_strides->get_pelement(indices);
uint8_t* this_pelement = this->get_pelement(indices);
this->set_value_at_pelement(src_pelement, src_pelement);
}
}
};

View File

@ -81,7 +81,7 @@ void __print_ndarray_aux(const char *format, bool first, bool last, SizeT* curso
SizeT* indices = (SizeT*) __builtin_alloca(sizeof(SizeT) * ndarray->ndims);
for (SizeT i = 0; i < dim; i++) {
ndarray_util::set_indices_by_nth(ndarray->ndims, ndarray->shape, indices, *cursor);
ElementT* pelement = (ElementT*) ndarray->get_pelement_by_indices(indices);
ElementT* pelement = (ElementT*) ndarray->get_pelement(indices);
ElementT element = *pelement;
if (i != 0) printf(", "); // List delimiter
@ -394,10 +394,10 @@ void test_ndslice_1() {
assert_arrays_match("shape", "%d", dst_ndims, expected_shape, dst_ndarray.shape);
assert_arrays_match("strides", "%d", dst_ndims, expected_strides, dst_ndarray.strides);
assert_values_match("dst_ndarray[0, 0]", "%f", 5.0, *((double *) dst_ndarray.get_pelement_by_indices((int32_t[dst_ndims]) { 0, 0 })));
assert_values_match("dst_ndarray[0, 1]", "%f", 7.0, *((double *) dst_ndarray.get_pelement_by_indices((int32_t[dst_ndims]) { 0, 1 })));
assert_values_match("dst_ndarray[1, 0]", "%f", 9.0, *((double *) dst_ndarray.get_pelement_by_indices((int32_t[dst_ndims]) { 1, 0 })));
assert_values_match("dst_ndarray[1, 1]", "%f", 11.0, *((double *) dst_ndarray.get_pelement_by_indices((int32_t[dst_ndims]) { 1, 1 })));
assert_values_match("dst_ndarray[0, 0]", "%f", 5.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 0, 0 })));
assert_values_match("dst_ndarray[0, 1]", "%f", 7.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 0, 1 })));
assert_values_match("dst_ndarray[1, 0]", "%f", 9.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 1, 0 })));
assert_values_match("dst_ndarray[1, 1]", "%f", 11.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 1, 1 })));
}
void test_ndslice_2() {
@ -471,8 +471,8 @@ void test_ndslice_2() {
assert_arrays_match("strides", "%d", dst_ndims, expected_strides, dst_ndarray.strides);
// [5.0, 3.0]
assert_values_match("dst_ndarray[0]", "%f", 11.0, *((double *) dst_ndarray.get_pelement_by_indices((int32_t[dst_ndims]) { 0 })));
assert_values_match("dst_ndarray[1]", "%f", 9.0, *((double *) dst_ndarray.get_pelement_by_indices((int32_t[dst_ndims]) { 1 })));
assert_values_match("dst_ndarray[0]", "%f", 11.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 0 })));
assert_values_match("dst_ndarray[1]", "%f", 9.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 1 })));
}
void test_can_broadcast_shape() {
@ -618,15 +618,24 @@ void test_ndarray_broadcast_1() {
assert_arrays_match("dst_ndarray->strides", "%d", dst_ndims, (int32_t[]) { 0, 0, 8 }, dst_ndarray.strides);
assert_values_match("dst_ndarray[0, 0, 0]", "%f", 19.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 0, 0})));
assert_values_match("dst_ndarray[0, 0, 1]", "%f", 29.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 0, 1})));
assert_values_match("dst_ndarray[0, 0, 2]", "%f", 39.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 0, 2})));
assert_values_match("dst_ndarray[0, 0, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 0, 3})));
assert_values_match("dst_ndarray[0, 1, 0]", "%f", 19.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 1, 0})));
assert_values_match("dst_ndarray[0, 1, 1]", "%f", 29.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 1, 1})));
assert_values_match("dst_ndarray[0, 1, 2]", "%f", 39.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 1, 2})));
assert_values_match("dst_ndarray[0, 1, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 1, 3})));
assert_values_match("dst_ndarray[1, 2, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {1, 2, 3})));
assert_values_match("dst_ndarray[0, 0, 0]", "%f", 19.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 0, 0})));
assert_values_match("dst_ndarray[0, 0, 1]", "%f", 29.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 0, 1})));
assert_values_match("dst_ndarray[0, 0, 2]", "%f", 39.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 0, 2})));
assert_values_match("dst_ndarray[0, 0, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 0, 3})));
assert_values_match("dst_ndarray[0, 1, 0]", "%f", 19.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 1, 0})));
assert_values_match("dst_ndarray[0, 1, 1]", "%f", 29.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 1, 1})));
assert_values_match("dst_ndarray[0, 1, 2]", "%f", 39.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 1, 2})));
assert_values_match("dst_ndarray[0, 1, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 1, 3})));
assert_values_match("dst_ndarray[1, 2, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {1, 2, 3})));
}
void test_assign_with() {
/*
```
xs = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype=np.float64)
ys = xs.shape
```
*/
}
int main() {
@ -644,5 +653,6 @@ int main() {
test_ndslice_2();
test_can_broadcast_shape();
test_ndarray_broadcast_1();
test_assign_with();
return 0;
}

View File

@ -702,54 +702,53 @@ pub fn call_numpy_min<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n)
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
todo!()
// let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
// let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
// let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
// let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None));
// if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
// let n_sz_eqz = ctx
// .builder
// .build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "")
// .unwrap();
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None));
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let n_sz_eqz = ctx
.builder
.build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "")
.unwrap();
// ctx.make_assert(
// generator,
// n_sz_eqz,
// "0:ValueError",
// "zero-size array to reduction operation minimum which has no identity",
// [None, None, None],
// ctx.current_loc,
// );
// }
ctx.make_assert(
generator,
n_sz_eqz,
"0:ValueError",
"zero-size array to reduction operation minimum which has no identity",
[None, None, None],
ctx.current_loc,
);
}
// let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
// unsafe {
// let identity =
// n.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
// ctx.builder.build_store(accumulator_addr, identity).unwrap();
// }
let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
unsafe {
let identity =
n.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
ctx.builder.build_store(accumulator_addr, identity).unwrap();
}
// gen_for_callback_incrementing(
// generator,
// ctx,
// llvm_usize.const_int(1, false),
// (n_sz, false),
// |generator, ctx, _, idx| {
// let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) };
gen_for_callback_incrementing(
generator,
ctx,
llvm_usize.const_int(1, false),
(n_sz, false),
|generator, ctx, _, idx| {
let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) };
// let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
// let result = call_min(ctx, (elem_ty, accumulator), (elem_ty, elem));
// ctx.builder.build_store(accumulator_addr, result).unwrap();
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
let result = call_min(ctx, (elem_ty, accumulator), (elem_ty, elem));
ctx.builder.build_store(accumulator_addr, result).unwrap();
// Ok(())
// },
// llvm_usize.const_int(1, false),
// )?;
Ok(())
},
llvm_usize.const_int(1, false),
)?;
// let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
// accumulator
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
accumulator
}
_ => unsupported_type(ctx, FN_NAME, &[a_ty]),
@ -921,54 +920,53 @@ pub fn call_numpy_max<'ctx, G: CodeGenerator + ?Sized>(
BasicValueEnum::PointerValue(n)
if a_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) =>
{
todo!()
// let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
// let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, a_ty);
let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty);
// let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
// let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None));
// if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
// let n_sz_eqz = ctx
// .builder
// .build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "")
// .unwrap();
let n = NDArrayValue::from_ptr_val(n, llvm_usize, None);
let n_sz = irrt::call_ndarray_calc_size(generator, ctx, &n.dim_sizes(), (None, None));
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
let n_sz_eqz = ctx
.builder
.build_int_compare(IntPredicate::NE, n_sz, n_sz.get_type().const_zero(), "")
.unwrap();
// ctx.make_assert(
// generator,
// n_sz_eqz,
// "0:ValueError",
// "zero-size array to reduction operation minimum which has no identity",
// [None, None, None],
// ctx.current_loc,
// );
// }
ctx.make_assert(
generator,
n_sz_eqz,
"0:ValueError",
"zero-size array to reduction operation minimum which has no identity",
[None, None, None],
ctx.current_loc,
);
}
// let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
// unsafe {
// let identity =
// n.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
// ctx.builder.build_store(accumulator_addr, identity).unwrap();
// }
let accumulator_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
unsafe {
let identity =
n.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
ctx.builder.build_store(accumulator_addr, identity).unwrap();
}
// gen_for_callback_incrementing(
// generator,
// ctx,
// llvm_usize.const_int(1, false),
// (n_sz, false),
// |generator, ctx, _, idx| {
// let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) };
gen_for_callback_incrementing(
generator,
ctx,
llvm_usize.const_int(1, false),
(n_sz, false),
|generator, ctx, _, idx| {
let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) };
// let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
// let result = call_max(ctx, (elem_ty, accumulator), (elem_ty, elem));
// ctx.builder.build_store(accumulator_addr, result).unwrap();
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
let result = call_max(ctx, (elem_ty, accumulator), (elem_ty, elem));
ctx.builder.build_store(accumulator_addr, result).unwrap();
// Ok(())
// },
// llvm_usize.const_int(1, false),
// )?;
Ok(())
},
llvm_usize.const_int(1, false),
)?;
// let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
// accumulator
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
accumulator
}
_ => unsupported_type(ctx, FN_NAME, &[a_ty]),

View File

@ -1,8 +1,5 @@
use crate::codegen::{
// irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
llvm_intrinsics::call_int_umin,
stmt::gen_for_callback_incrementing,
CodeGenContext,
llvm_intrinsics::call_int_umin, stmt::gen_for_callback_incrementing, CodeGenContext,
CodeGenerator,
};
use inkwell::context::Context;
@ -1210,27 +1207,25 @@ impl<'ctx> NDArrayType<'ctx> {
ctx: &'ctx Context,
dtype: BasicTypeEnum<'ctx>,
) -> Self {
todo!()
let llvm_usize = generator.get_size_type(ctx);
// let llvm_usize = generator.get_size_type(ctx);
// struct NDArray { num_dims: size_t, dims: size_t*, data: T* }
//
// * num_dims: Number of dimensions in the array
// * dims: Pointer to an array containing the size of each dimension
// * data: Pointer to an array containing the array data
let llvm_ndarray = ctx
.struct_type(
&[
llvm_usize.into(),
llvm_usize.ptr_type(AddressSpace::default()).into(),
dtype.ptr_type(AddressSpace::default()).into(),
],
false,
)
.ptr_type(AddressSpace::default());
// // struct NDArray { num_dims: size_t, dims: size_t*, data: T* }
// //
// // * num_dims: Number of dimensions in the array
// // * dims: Pointer to an array containing the size of each dimension
// // * data: Pointer to an array containing the array data
// let llvm_ndarray = ctx
// .struct_type(
// &[
// llvm_usize.into(),
// llvm_usize.ptr_type(AddressSpace::default()).into(),
// dtype.ptr_type(AddressSpace::default()).into(),
// ],
// false,
// )
// .ptr_type(AddressSpace::default());
// NDArrayType::from_type(llvm_ndarray, llvm_usize)
NDArrayType::from_type(llvm_ndarray, llvm_usize)
}
/// Creates an [`NDArrayType`] from a [`PointerType`].
@ -1664,22 +1659,23 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
indices: &Index,
name: Option<&str>,
) -> PointerValue<'ctx> {
todo!()
// let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_usize = generator.get_size_type(ctx.ctx);
// let indices_elem_ty = indices
// .ptr_offset(ctx, generator, &llvm_usize.const_zero(), None)
// .get_type()
// .get_element_type();
// let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else {
// panic!("Expected list[int32] but got {indices_elem_ty}")
// };
// assert_eq!(
// indices_elem_ty.get_bit_width(),
// 32,
// "Expected list[int32] but got list[int{}]",
// indices_elem_ty.get_bit_width()
// );
let indices_elem_ty = indices
.ptr_offset(ctx, generator, &llvm_usize.const_zero(), None)
.get_type()
.get_element_type();
let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else {
panic!("Expected list[int32] but got {indices_elem_ty}")
};
assert_eq!(
indices_elem_ty.get_bit_width(),
32,
"Expected list[int32] but got list[int{}]",
indices_elem_ty.get_bit_width()
);
todo!()
// let index = call_ndarray_flatten_index(generator, ctx, *self.0, indices);
@ -1801,39 +1797,27 @@ struct StructFieldsBuilder<'ctx> {
}
impl<'ctx> StructField<'ctx> {
/// TODO: DOCUMENT ME
pub fn gep(
&self,
ctx: &CodeGenContext<'ctx, '_>,
struct_ptr: PointerValue<'ctx>,
ptr: PointerValue<'ctx>,
) -> PointerValue<'ctx> {
let index_type = ctx.ctx.i32_type(); // TODO: I think I'm not supposed to use i32 for GEP like that
unsafe {
ctx.builder
.build_in_bounds_gep(
struct_ptr,
&[index_type.const_zero(), index_type.const_int(self.gep_index as u64, false)],
self.name,
)
.unwrap()
}
ctx.builder.build_struct_gep(ptr, self.gep_index, self.name).unwrap()
}
/// TODO: DOCUMENT ME
pub fn load(
&self,
ctx: &CodeGenContext<'ctx, '_>,
struct_ptr: PointerValue<'ctx>,
ptr: PointerValue<'ctx>,
) -> BasicValueEnum<'ctx> {
ctx.builder.build_load(self.gep(ctx, struct_ptr), self.name).unwrap()
ctx.builder.build_load(self.gep(ctx, ptr), self.name).unwrap()
}
/// TODO: DOCUMENT ME
pub fn store<V>(&self, ctx: &CodeGenContext<'ctx, '_>, struct_ptr: PointerValue<'ctx>, value: V)
pub fn store<V>(&self, ctx: &CodeGenContext<'ctx, '_>, ptr: PointerValue<'ctx>, value: V)
where
V: BasicValue<'ctx>,
{
ctx.builder.build_store(self.gep(ctx, struct_ptr), value).unwrap();
ctx.builder.build_store(ptr, value).unwrap();
}
}
@ -1872,7 +1856,7 @@ impl<'ctx> StructFields<'ctx> {
self.fields.len() as u32
}
pub fn get_struct_type(&self, ctx: &'ctx Context) -> StructType<'ctx> {
pub fn as_struct_type(&self, ctx: &'ctx Context) -> StructType<'ctx> {
let llvm_fields = self.fields.iter().map(|field| field.ty).collect_vec();
ctx.struct_type(llvm_fields.as_slice(), false)
}
@ -1916,11 +1900,7 @@ impl<'ctx> StructFieldsBuilder<'ctx> {
fn add_field(&mut self, name: &'static str, ty: BasicTypeEnum<'ctx>) -> StructField<'ctx> {
let index = self.gep_index_counter;
self.gep_index_counter += 1;
let field = StructField { gep_index: index, name, ty };
self.fields.push(field); // Register into self.fields
field // Return to the caller to conveniently let them do whatever they want
StructField { gep_index: index, name, ty }
}
fn end(self) -> StructFields<'ctx> {
@ -1951,8 +1931,8 @@ impl<'ctx> NpArrayType<'ctx> {
NpArrayType { size_type, elem_type: ctx.ctx.i8_type().as_basic_type_enum() }
}
pub fn get_struct_type(&self, ctx: &'ctx Context) -> StructType<'ctx> {
self.fields().whole_struct.get_struct_type(ctx)
pub fn struct_type(&self, ctx: &CodeGenContext<'ctx, '_>) -> StructType<'ctx> {
self.fields().whole_struct.as_struct_type(ctx.ctx)
}
pub fn fields(&self) -> NpArrayStructFields<'ctx> {
@ -1978,43 +1958,29 @@ impl<'ctx> NpArrayType<'ctx> {
/// - `ndarray.itemsize` will be initialized to the size of `self.elem_type.size_of()`.
/// - `ndarray.shape` and `ndarray.strides` will be allocated on the stack with number of elements being `in_ndims`,
/// all with empty/uninitialized values.
pub fn var_alloc<G>(
pub fn alloca(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ctx: &CodeGenContext<'ctx, '_>,
in_ndims: IntValue<'ctx>,
name: Option<&str>,
) -> NpArrayValue<'ctx>
where
G: CodeGenerator + ?Sized,
{
let ptr = generator
.gen_var_alloc(ctx, self.get_struct_type(ctx.ctx).as_basic_type_enum(), name)
.unwrap();
name: &str,
) -> NpArrayValue<'ctx> {
let fields = self.fields();
let ptr =
ctx.builder.build_alloca(fields.whole_struct.as_struct_type(ctx.ctx), name).unwrap();
// Allocate `in_dims` number of `size_type` on the stack for `shape` and `strides`
let allocated_shape = generator
.gen_array_var_alloc(
ctx,
self.size_type.as_basic_type_enum(),
in_ndims,
Some("allocated_shape"),
)
.unwrap();
let allocated_strides = generator
.gen_array_var_alloc(
ctx,
self.size_type.as_basic_type_enum(),
in_ndims,
Some("allocated_strides"),
)
let allocated_shape =
ctx.builder.build_array_alloca(fields.shape.ty, in_ndims, "allocated_shape").unwrap();
let allocated_strides = ctx
.builder
.build_array_alloca(fields.strides.ty, in_ndims, "allocated_strides")
.unwrap();
let value = NpArrayValue { ty: *self, ptr };
value.store_ndims(ctx, in_ndims);
value.store_itemsize(ctx, self.elem_type.size_of().unwrap());
value.store_shape(ctx, allocated_shape.base_ptr(ctx, generator));
value.store_strides(ctx, allocated_strides.base_ptr(ctx, generator));
value.store_shape(ctx, allocated_shape);
value.store_strides(ctx, allocated_strides);
return value;
}
@ -2072,15 +2038,13 @@ impl<'ctx> NpArrayValue<'ctx> {
&self,
ctx: &CodeGenContext<'ctx, '_>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
// Get the pointer to `shape`
let field = self.ty.fields().shape;
let shape = field.load(ctx, self.ptr).into_pointer_value();
field.gep(ctx, self.ptr);
// Load `ndims`
let ndims = self.load_ndims(ctx);
TypedArrayLikeAdapter {
adapted: ArraySliceValue(shape, ndims, Some(field.name)),
adapted: ArraySliceValue(self.ptr, ndims, Some(field.name)),
downcast_fn: Box::new(|_ctx, x| x.into_int_value()),
upcast_fn: Box::new(|_ctx, x| x.as_basic_value_enum()),
}
@ -2091,15 +2055,13 @@ impl<'ctx> NpArrayValue<'ctx> {
&self,
ctx: &CodeGenContext<'ctx, '_>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
// Get the pointer to `strides`
let field = self.ty.fields().strides;
let strides = field.load(ctx, self.ptr).into_pointer_value();
field.gep(ctx, self.ptr);
// Load `ndims`
let ndims = self.load_ndims(ctx);
TypedArrayLikeAdapter {
adapted: ArraySliceValue(strides, ndims, Some(field.name)),
adapted: ArraySliceValue(self.ptr, ndims, Some(field.name)),
downcast_fn: Box::new(|_ctx, x| x.into_int_value()),
upcast_fn: Box::new(|_ctx, x| x.as_basic_value_enum()),
}

View File

@ -4,8 +4,8 @@ mod test;
use super::{
classes::{
check_basic_types_match, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue,
NDArrayValue, NpArrayType, NpArrayValue, TypedArrayLikeAdapter, UntypedArrayLikeAccessor,
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue, NpArrayType,
NpArrayValue, TypedArrayLikeAdapter, UntypedArrayLikeAccessor,
},
llvm_intrinsics, CodeGenContext, CodeGenerator,
};
@ -17,7 +17,7 @@ use inkwell::{
memory_buffer::MemoryBuffer,
module::Module,
types::{BasicType, BasicTypeEnum, FunctionType, IntType, PointerType},
values::{BasicValue, BasicValueEnum, CallSiteValue, FloatValue, FunctionValue, IntValue},
values::{BasicValueEnum, CallSiteValue, FloatValue, FunctionValue, IntValue},
AddressSpace, IntPredicate,
};
use itertools::Either;
@ -565,370 +565,370 @@ pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> Flo
.unwrap()
}
// /// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the
// /// calculated total size.
// ///
// /// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension.
// /// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for,
// /// or [`None`] if starting from the first dimension and ending at the last dimension respectively.
// pub fn call_ndarray_calc_size<'ctx, G, Dims>(
// generator: &G,
// ctx: &CodeGenContext<'ctx, '_>,
// dims: &Dims,
// (begin, end): (Option<IntValue<'ctx>>, Option<IntValue<'ctx>>),
// ) -> IntValue<'ctx>
// where
// G: CodeGenerator + ?Sized,
// Dims: ArrayLikeIndexer<'ctx>,
// {
// let llvm_usize = generator.get_size_type(ctx.ctx);
// let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
//
// let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() {
// 32 => "__nac3_ndarray_calc_size",
// 64 => "__nac3_ndarray_calc_size64",
// bw => unreachable!("Unsupported size type bit width: {}", bw),
// };
// let ndarray_calc_size_fn_t = llvm_usize.fn_type(
// &[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()],
// false,
// );
// let ndarray_calc_size_fn =
// ctx.module.get_function(ndarray_calc_size_fn_name).unwrap_or_else(|| {
// ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None)
// });
//
// let begin = begin.unwrap_or_else(|| llvm_usize.const_zero());
// let end = end.unwrap_or_else(|| dims.size(ctx, generator));
// ctx.builder
// .build_call(
// ndarray_calc_size_fn,
// &[
// dims.base_ptr(ctx, generator).into(),
// dims.size(ctx, generator).into(),
// begin.into(),
// end.into(),
// ],
// "",
// )
// .map(CallSiteValue::try_as_basic_value)
// .map(|v| v.map_left(BasicValueEnum::into_int_value))
// .map(Either::unwrap_left)
// .unwrap()
// }
//
// /// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdpater`]
// /// containing `i32` indices of the flattened index.
// ///
// /// * `index` - The index to compute the multidimensional index for.
// /// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
// /// `NDArray`.
// pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
// generator: &G,
// ctx: &mut CodeGenContext<'ctx, '_>,
// index: IntValue<'ctx>,
// ndarray: NDArrayValue<'ctx>,
// ) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
// let llvm_void = ctx.ctx.void_type();
// let llvm_i32 = ctx.ctx.i32_type();
// let llvm_usize = generator.get_size_type(ctx.ctx);
// let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
// let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
//
// let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() {
// 32 => "__nac3_ndarray_calc_nd_indices",
// 64 => "__nac3_ndarray_calc_nd_indices64",
// bw => unreachable!("Unsupported size type bit width: {}", bw),
// };
// let ndarray_calc_nd_indices_fn =
// ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| {
// let fn_type = llvm_void.fn_type(
// &[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()],
// false,
// );
//
// ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None)
// });
//
// let ndarray_num_dims = ndarray.load_ndims(ctx);
// let ndarray_dims = ndarray.dim_sizes();
//
// let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap();
//
// ctx.builder
// .build_call(
// ndarray_calc_nd_indices_fn,
// &[
// index.into(),
// ndarray_dims.base_ptr(ctx, generator).into(),
// ndarray_num_dims.into(),
// indices.into(),
// ],
// "",
// )
// .unwrap();
//
// TypedArrayLikeAdapter::from(
// ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None),
// Box::new(|_, v| v.into_int_value()),
// Box::new(|_, v| v.into()),
// )
// }
//
// fn call_ndarray_flatten_index_impl<'ctx, G, Indices>(
// generator: &G,
// ctx: &CodeGenContext<'ctx, '_>,
// ndarray: NDArrayValue<'ctx>,
// indices: &Indices,
// ) -> IntValue<'ctx>
// where
// G: CodeGenerator + ?Sized,
// Indices: ArrayLikeIndexer<'ctx>,
// {
// let llvm_i32 = ctx.ctx.i32_type();
// let llvm_usize = generator.get_size_type(ctx.ctx);
//
// let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
// let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
//
// debug_assert_eq!(
// IntType::try_from(indices.element_type(ctx, generator))
// .map(IntType::get_bit_width)
// .unwrap_or_default(),
// llvm_i32.get_bit_width(),
// "Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`"
// );
// debug_assert_eq!(
// indices.size(ctx, generator).get_type().get_bit_width(),
// llvm_usize.get_bit_width(),
// "Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`"
// );
//
// let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() {
// 32 => "__nac3_ndarray_flatten_index",
// 64 => "__nac3_ndarray_flatten_index64",
// bw => unreachable!("Unsupported size type bit width: {}", bw),
// };
// let ndarray_flatten_index_fn =
// ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| {
// let fn_type = llvm_usize.fn_type(
// &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()],
// false,
// );
//
// ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None)
// });
//
// let ndarray_num_dims = ndarray.load_ndims(ctx);
// let ndarray_dims = ndarray.dim_sizes();
//
// let index = ctx
// .builder
// .build_call(
// ndarray_flatten_index_fn,
// &[
// ndarray_dims.base_ptr(ctx, generator).into(),
// ndarray_num_dims.into(),
// indices.base_ptr(ctx, generator).into(),
// indices.size(ctx, generator).into(),
// ],
// "",
// )
// .map(CallSiteValue::try_as_basic_value)
// .map(|v| v.map_left(BasicValueEnum::into_int_value))
// .map(Either::unwrap_left)
// .unwrap();
//
// index
// }
//
// /// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
// /// multidimensional index.
// ///
// /// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
// /// `NDArray`.
// /// * `indices` - The multidimensional index to compute the flattened index for.
// pub fn call_ndarray_flatten_index<'ctx, G, Index>(
// generator: &mut G,
// ctx: &mut CodeGenContext<'ctx, '_>,
// ndarray: NDArrayValue<'ctx>,
// indices: &Index,
// ) -> IntValue<'ctx>
// where
// G: CodeGenerator + ?Sized,
// Index: ArrayLikeIndexer<'ctx>,
// {
// call_ndarray_flatten_index_impl(generator, ctx, ndarray, indices)
// }
//
// /// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of
// /// dimension and size of each dimension of the resultant `ndarray`.
// pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
// generator: &mut G,
// ctx: &mut CodeGenContext<'ctx, '_>,
// lhs: NDArrayValue<'ctx>,
// rhs: NDArrayValue<'ctx>,
// ) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
// let llvm_usize = generator.get_size_type(ctx.ctx);
// let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
//
// let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
// 32 => "__nac3_ndarray_calc_broadcast",
// 64 => "__nac3_ndarray_calc_broadcast64",
// bw => unreachable!("Unsupported size type bit width: {}", bw),
// };
// let ndarray_calc_broadcast_fn =
// ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
// let fn_type = llvm_usize.fn_type(
// &[
// llvm_pusize.into(),
// llvm_usize.into(),
// llvm_pusize.into(),
// llvm_usize.into(),
// llvm_pusize.into(),
// ],
// false,
// );
//
// ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
// });
//
// let lhs_ndims = lhs.load_ndims(ctx);
// let rhs_ndims = rhs.load_ndims(ctx);
// let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None);
//
// gen_for_callback_incrementing(
// generator,
// ctx,
// llvm_usize.const_zero(),
// (min_ndims, false),
// |generator, ctx, _, idx| {
// let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap();
// let (lhs_dim_sz, rhs_dim_sz) = unsafe {
// (
// lhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None),
// rhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None),
// )
// };
//
// let llvm_usize_const_one = llvm_usize.const_int(1, false);
// let lhs_eqz = ctx
// .builder
// .build_int_compare(IntPredicate::EQ, lhs_dim_sz, llvm_usize_const_one, "")
// .unwrap();
// let rhs_eqz = ctx
// .builder
// .build_int_compare(IntPredicate::EQ, rhs_dim_sz, llvm_usize_const_one, "")
// .unwrap();
// let lhs_or_rhs_eqz = ctx.builder.build_or(lhs_eqz, rhs_eqz, "").unwrap();
//
// let lhs_eq_rhs = ctx
// .builder
// .build_int_compare(IntPredicate::EQ, lhs_dim_sz, rhs_dim_sz, "")
// .unwrap();
//
// let is_compatible = ctx.builder.build_or(lhs_or_rhs_eqz, lhs_eq_rhs, "").unwrap();
//
// ctx.make_assert(
// generator,
// is_compatible,
// "0:ValueError",
// "operands could not be broadcast together",
// [None, None, None],
// ctx.current_loc,
// );
//
// Ok(())
// },
// llvm_usize.const_int(1, false),
// )
// .unwrap();
//
// let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None);
// let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator);
// let lhs_ndims = lhs.load_ndims(ctx);
// let rhs_dims = rhs.dim_sizes().base_ptr(ctx, generator);
// let rhs_ndims = rhs.load_ndims(ctx);
// let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap();
// let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None);
//
// ctx.builder
// .build_call(
// ndarray_calc_broadcast_fn,
// &[
// lhs_dims.into(),
// lhs_ndims.into(),
// rhs_dims.into(),
// rhs_ndims.into(),
// out_dims.base_ptr(ctx, generator).into(),
// ],
// "",
// )
// .unwrap();
//
// TypedArrayLikeAdapter::from(
// out_dims,
// Box::new(|_, v| v.into_int_value()),
// Box::new(|_, v| v.into()),
// )
// }
//
// /// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`]
// /// containing the indices used for accessing `array` corresponding to the index of the broadcasted
// /// array `broadcast_idx`.
// pub fn call_ndarray_calc_broadcast_index<
// 'ctx,
// G: CodeGenerator + ?Sized,
// BroadcastIdx: UntypedArrayLikeAccessor<'ctx>,
// >(
// generator: &mut G,
// ctx: &mut CodeGenContext<'ctx, '_>,
// array: NDArrayValue<'ctx>,
// broadcast_idx: &BroadcastIdx,
// ) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
// let llvm_i32 = ctx.ctx.i32_type();
// let llvm_usize = generator.get_size_type(ctx.ctx);
// let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
// let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
//
// let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
// 32 => "__nac3_ndarray_calc_broadcast_idx",
// 64 => "__nac3_ndarray_calc_broadcast_idx64",
// bw => unreachable!("Unsupported size type bit width: {}", bw),
// };
// let ndarray_calc_broadcast_fn =
// ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
// let fn_type = llvm_usize.fn_type(
// &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()],
// false,
// );
//
// ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
// });
//
// let broadcast_size = broadcast_idx.size(ctx, generator);
// let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap();
//
// let array_dims = array.dim_sizes().base_ptr(ctx, generator);
// let array_ndims = array.load_ndims(ctx);
// let broadcast_idx_ptr = unsafe {
// broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
// };
//
// ctx.builder
// .build_call(
// ndarray_calc_broadcast_fn,
// &[array_dims.into(), array_ndims.into(), broadcast_idx_ptr.into(), out_idx.into()],
// "",
// )
// .unwrap();
//
// TypedArrayLikeAdapter::from(
// ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None),
// Box::new(|_, v| v.into_int_value()),
// Box::new(|_, v| v.into()),
// )
// }
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [`IntValue`] representing the
/// calculated total size.
///
/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension.
/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for,
/// or [`None`] if starting from the first dimension and ending at the last dimension respectively.
pub fn call_ndarray_calc_size<'ctx, G, Dims>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
dims: &Dims,
(begin, end): (Option<IntValue<'ctx>>, Option<IntValue<'ctx>>),
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Dims: ArrayLikeIndexer<'ctx>,
{
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_size",
64 => "__nac3_ndarray_calc_size64",
bw => unreachable!("Unsupported size type bit width: {}", bw),
};
let ndarray_calc_size_fn_t = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()],
false,
);
let ndarray_calc_size_fn =
ctx.module.get_function(ndarray_calc_size_fn_name).unwrap_or_else(|| {
ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None)
});
let begin = begin.unwrap_or_else(|| llvm_usize.const_zero());
let end = end.unwrap_or_else(|| dims.size(ctx, generator));
ctx.builder
.build_call(
ndarray_calc_size_fn,
&[
dims.base_ptr(ctx, generator).into(),
dims.size(ctx, generator).into(),
begin.into(),
end.into(),
],
"",
)
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()
}
/// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdpater`]
/// containing `i32` indices of the flattened index.
///
/// * `index` - The index to compute the multidimensional index for.
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
/// `NDArray`.
pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
generator: &G,
ctx: &mut CodeGenContext<'ctx, '_>,
index: IntValue<'ctx>,
ndarray: NDArrayValue<'ctx>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_void = ctx.ctx.void_type();
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_nd_indices",
64 => "__nac3_ndarray_calc_nd_indices64",
bw => unreachable!("Unsupported size type bit width: {}", bw),
};
let ndarray_calc_nd_indices_fn =
ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| {
let fn_type = llvm_void.fn_type(
&[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()],
false,
);
ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None)
});
let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.dim_sizes();
let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap();
ctx.builder
.build_call(
ndarray_calc_nd_indices_fn,
&[
index.into(),
ndarray_dims.base_ptr(ctx, generator).into(),
ndarray_num_dims.into(),
indices.into(),
],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(
ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None),
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
}
fn call_ndarray_flatten_index_impl<'ctx, G, Indices>(
generator: &G,
ctx: &CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: &Indices,
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Indices: ArrayLikeIndexer<'ctx>,
{
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
debug_assert_eq!(
IntType::try_from(indices.element_type(ctx, generator))
.map(IntType::get_bit_width)
.unwrap_or_default(),
llvm_i32.get_bit_width(),
"Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`"
);
debug_assert_eq!(
indices.size(ctx, generator).get_type().get_bit_width(),
llvm_usize.get_bit_width(),
"Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`"
);
let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_flatten_index",
64 => "__nac3_ndarray_flatten_index64",
bw => unreachable!("Unsupported size type bit width: {}", bw),
};
let ndarray_flatten_index_fn =
ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()],
false,
);
ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None)
});
let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.dim_sizes();
let index = ctx
.builder
.build_call(
ndarray_flatten_index_fn,
&[
ndarray_dims.base_ptr(ctx, generator).into(),
ndarray_num_dims.into(),
indices.base_ptr(ctx, generator).into(),
indices.size(ctx, generator).into(),
],
"",
)
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap();
index
}
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the
/// multidimensional index.
///
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
/// `NDArray`.
/// * `indices` - The multidimensional index to compute the flattened index for.
pub fn call_ndarray_flatten_index<'ctx, G, Index>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ndarray: NDArrayValue<'ctx>,
indices: &Index,
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Index: ArrayLikeIndexer<'ctx>,
{
call_ndarray_flatten_index_impl(generator, ctx, ndarray, indices)
}
/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of
/// dimension and size of each dimension of the resultant `ndarray`.
pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
lhs: NDArrayValue<'ctx>,
rhs: NDArrayValue<'ctx>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast",
64 => "__nac3_ndarray_calc_broadcast64",
bw => unreachable!("Unsupported size type bit width: {}", bw),
};
let ndarray_calc_broadcast_fn =
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[
llvm_pusize.into(),
llvm_usize.into(),
llvm_pusize.into(),
llvm_usize.into(),
llvm_pusize.into(),
],
false,
);
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
});
let lhs_ndims = lhs.load_ndims(ctx);
let rhs_ndims = rhs.load_ndims(ctx);
let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None);
gen_for_callback_incrementing(
generator,
ctx,
llvm_usize.const_zero(),
(min_ndims, false),
|generator, ctx, _, idx| {
let idx = ctx.builder.build_int_sub(min_ndims, idx, "").unwrap();
let (lhs_dim_sz, rhs_dim_sz) = unsafe {
(
lhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None),
rhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None),
)
};
let llvm_usize_const_one = llvm_usize.const_int(1, false);
let lhs_eqz = ctx
.builder
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, llvm_usize_const_one, "")
.unwrap();
let rhs_eqz = ctx
.builder
.build_int_compare(IntPredicate::EQ, rhs_dim_sz, llvm_usize_const_one, "")
.unwrap();
let lhs_or_rhs_eqz = ctx.builder.build_or(lhs_eqz, rhs_eqz, "").unwrap();
let lhs_eq_rhs = ctx
.builder
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, rhs_dim_sz, "")
.unwrap();
let is_compatible = ctx.builder.build_or(lhs_or_rhs_eqz, lhs_eq_rhs, "").unwrap();
ctx.make_assert(
generator,
is_compatible,
"0:ValueError",
"operands could not be broadcast together",
[None, None, None],
ctx.current_loc,
);
Ok(())
},
llvm_usize.const_int(1, false),
)
.unwrap();
let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None);
let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator);
let lhs_ndims = lhs.load_ndims(ctx);
let rhs_dims = rhs.dim_sizes().base_ptr(ctx, generator);
let rhs_ndims = rhs.load_ndims(ctx);
let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap();
let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None);
ctx.builder
.build_call(
ndarray_calc_broadcast_fn,
&[
lhs_dims.into(),
lhs_ndims.into(),
rhs_dims.into(),
rhs_ndims.into(),
out_dims.base_ptr(ctx, generator).into(),
],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(
out_dims,
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
}
/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`]
/// containing the indices used for accessing `array` corresponding to the index of the broadcasted
/// array `broadcast_idx`.
pub fn call_ndarray_calc_broadcast_index<
'ctx,
G: CodeGenerator + ?Sized,
BroadcastIdx: UntypedArrayLikeAccessor<'ctx>,
>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
array: NDArrayValue<'ctx>,
broadcast_idx: &BroadcastIdx,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast_idx",
64 => "__nac3_ndarray_calc_broadcast_idx64",
bw => unreachable!("Unsupported size type bit width: {}", bw),
};
let ndarray_calc_broadcast_fn =
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()],
false,
);
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
});
let broadcast_size = broadcast_idx.size(ctx, generator);
let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap();
let array_dims = array.dim_sizes().base_ptr(ctx, generator);
let array_ndims = array.load_ndims(ctx);
let broadcast_idx_ptr = unsafe {
broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
ctx.builder
.build_call(
ndarray_calc_broadcast_fn,
&[array_dims.into(), array_ndims.into(), broadcast_idx_ptr.into(), out_idx.into()],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(
ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None),
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
}
fn get_size_variant<'ctx>(ty: IntType<'ctx>) -> SizeVariant {
match ty.get_bit_width() {
@ -965,28 +965,21 @@ where
})
}
fn get_irrt_ndarray_ptr_type<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
size_type: IntType<'ctx>,
) -> PointerType<'ctx> {
let i8_type = ctx.ctx.i8_type();
fn get_ndarray_struct_ptr<'ctx>(ctx: &'ctx Context, size_type: IntType<'ctx>) -> PointerType<'ctx> {
let i8_type = ctx.i8_type();
let ndarray_ty = NpArrayType { size_type, elem_type: i8_type.as_basic_type_enum() };
let struct_ty = ndarray_ty.get_struct_type(ctx.ctx);
let struct_ty = ndarray_ty.fields().whole_struct.as_struct_type(ctx);
struct_ty.ptr_type(AddressSpace::default())
}
fn get_irrt_opaque_uint8_ptr_type<'ctx>(ctx: &CodeGenContext<'ctx, '_>) -> PointerType<'ctx> {
ctx.ctx.i8_type().ptr_type(AddressSpace::default())
}
pub fn call_nac3_ndarray_size<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
ndarray: NpArrayValue<'ctx>,
) -> IntValue<'ctx> {
let size_type = ndarray.ty.size_type;
let function = get_size_type_dependent_function(ctx, size_type, "__nac3_ndarray_size", || {
size_type.fn_type(&[get_irrt_ndarray_ptr_type(ctx, size_type).into()], false)
size_type.fn_type(&[get_ndarray_struct_ptr(ctx.ctx, size_type).into()], false)
});
ctx.builder
@ -996,44 +989,3 @@ pub fn call_nac3_ndarray_size<'ctx>(
.unwrap_left()
.into_int_value()
}
pub fn call_nac3_ndarray_fill_generic<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
ndarray: NpArrayValue<'ctx>,
fill_value: BasicValueEnum<'ctx>,
) {
// Sanity check on type of `fill_value`
check_basic_types_match(ndarray.ty.elem_type, fill_value.get_type().as_basic_type_enum())
.unwrap();
let size_type = ndarray.ty.size_type;
let function =
get_size_type_dependent_function(ctx, size_type, "__nac3_ndarray_fill_generic", || {
ctx.ctx.void_type().fn_type(
&[
get_irrt_ndarray_ptr_type(ctx, size_type).into(), // NDArray<SizeT>* ndarray
get_irrt_opaque_uint8_ptr_type(ctx).into(), // uint8_t* pvalue
],
false,
)
});
// Put `fill_value` onto the stack and get a pointer to it, and that pointer will be `pvalue`
let pvalue = ctx.builder.build_alloca(ndarray.ty.elem_type, "fill_value").unwrap();
ctx.builder.build_store(pvalue, fill_value).unwrap();
// Cast pvalue to `uint8_t*`
let pvalue = ctx.builder.build_pointer_cast(pvalue, get_irrt_opaque_uint8_ptr_type(ctx), "").unwrap();
// Call the IRRT function
ctx.builder
.build_call(
function,
&[
ndarray.ptr.into(), // ndarray
pvalue.into(), // pvalue
],
"",
)
.unwrap();
}

View File

@ -7,7 +7,6 @@ use crate::{
typedef::{CallId, FuncArg, Type, TypeEnum, Unifier},
},
};
use classes::NpArrayType;
use crossbeam::channel::{unbounded, Receiver, Sender};
use inkwell::{
attributes::{Attribute, AttributeLoc},
@ -477,11 +476,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
ctx, module, generator, unifier, top_level, type_cache, dtype,
);
let ndarray_ty = NpArrayType {
size_type: generator.get_size_type(ctx),
elem_type: element_type,
};
ndarray_ty.get_struct_type(ctx).as_basic_type_enum()
NDArrayType::new(generator, ctx, element_type).as_base_type().into()
}
_ => unreachable!(

View File

@ -2,11 +2,15 @@ use crate::{
codegen::{
classes::{
ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, NDArrayType, NDArrayValue,
NpArrayType, ProxyType, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter,
ProxyType, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter,
TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
},
expr::gen_binop_expr_with_values,
irrt::call_nac3_ndarray_fill_generic,
irrt::{
calculate_len_for_slice_range, call_ndarray_calc_broadcast,
call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices,
call_ndarray_calc_size,
},
llvm_intrinsics::{self, call_memcpy_generic},
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
CodeGenContext, CodeGenerator,
@ -22,7 +26,7 @@ use crate::{
typedef::{FunSignature, Type, TypeEnum},
},
};
use inkwell::types::{AnyTypeEnum, BasicTypeEnum, IntType, PointerType};
use inkwell::types::{AnyTypeEnum, BasicTypeEnum, PointerType};
use inkwell::{
types::BasicType,
values::{BasicValueEnum, IntValue, PointerValue},
@ -30,8 +34,6 @@ use inkwell::{
};
use nac3parser::ast::{Operator, StrRef};
use super::{classes::NpArrayValue, stmt::gen_return};
// /// Creates an uninitialized `NDArray` instance.
// fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
// generator: &mut G,
@ -2013,335 +2015,3 @@ use super::{classes::NpArrayValue, stmt::gen_return};
// Ok(())
// }
//
fn simple_assert<'ctx, G>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
cond: IntValue<'ctx>,
msg: &str,
) where
G: CodeGenerator + ?Sized,
{
let mut full_msg = String::from("simple_assert failed: ");
full_msg.push_str(msg);
ctx.make_assert(
generator,
cond,
"0:ValueError",
full_msg.as_str(),
[None, None, None],
ctx.current_loc,
);
}
fn copy_array_slice<'ctx, G, Src, Dst>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
dst: Dst,
src: Src,
) where
G: CodeGenerator + ?Sized,
Dst: TypedArrayLikeMutator<'ctx, IntType<'ctx>>,
Src: TypedArrayLikeAccessor<'ctx, IntType<'ctx>>,
{
// Sanity check
let len_match = ctx
.builder
.build_int_compare(
IntPredicate::EQ,
src.size(ctx, generator),
dst.size(ctx, generator),
"len_match",
)
.unwrap();
simple_assert(generator, ctx, len_match, "copy_array_slice length mismatched");
let size_type = generator.get_size_type(ctx.ctx);
let init_val = size_type.const_zero();
let max_val = (dst.size(ctx, generator), false);
let incr_val = size_type.const_int(1, false);
gen_for_callback_incrementing(
generator,
ctx,
init_val,
max_val,
|generator, ctx, _hooks, idx| {
let value = src.get_typed(ctx, generator, &idx, Some("copy_array_slice.tmp"));
dst.set_typed(ctx, generator, &idx, value);
Ok(())
},
incr_val,
)
.unwrap();
}
pub fn alloca_ndarray_uninitialized<'ctx, G>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_type: BasicTypeEnum<'ctx>,
ndims: IntValue<'ctx>,
name: Option<&str>,
) -> Result<NpArrayValue<'ctx>, String>
where
G: CodeGenerator + ?Sized,
{
let size_type = generator.get_size_type(ctx.ctx);
let ndarray_ty = NpArrayType { size_type, elem_type };
let ndarray = ndarray_ty.var_alloc(generator, ctx, ndims, name);
Ok(ndarray)
}
pub struct Producer<'ctx, G: CodeGenerator + ?Sized, T> {
pub count: IntValue<'ctx>,
pub write_to_slice: Box<
dyn Fn(
&mut G,
&mut CodeGenContext<'ctx, '_>,
&TypedArrayLikeAdapter<'ctx, T>,
) -> Result<(), String>
+ 'ctx,
>,
}
/// TODO: UPDATE DOCUMENTATION
/// LLVM-typed implementation for generating a [`Producer`] that sets a list of ints.
///
/// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The `shape` parameter used to construct the `NDArray`.
///
/// ### Notes on `shape`
///
/// Just like numpy, the `shape` argument can be:
/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))`
/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
///
/// See also [`typecheck::type_inferencer::fold_numpy_function_call_shape_argument`] to
/// learn how `shape` gets from being a Python user expression to here.
fn parse_input_shape_arg<'ctx, G>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
shape: BasicValueEnum<'ctx>,
shape_ty: Type,
) -> Result<Producer<'ctx, G, IntValue<'ctx>>, String>
where
G: CodeGenerator + ?Sized,
{
let size_type = generator.get_size_type(ctx.ctx);
match &*ctx.unifier.get_ty(shape_ty) {
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
{
// 1. A list of ints; e.g., `np.empty([600, 800, 3])`
// A list has to be a PointerValue
let shape_list = ListValue::from_ptr_val(shape.into_pointer_value(), size_type, None);
// Create `Producer`
let ndims = shape_list.load_size(ctx, Some("count"));
Ok(Producer {
count: ndims,
write_to_slice: Box::new(move |ctx, generator, dst_slice| {
// Basically iterate through the list and write to `dst_slice` accordingly
let init_val = size_type.const_zero();
let max_val = (ndims, false);
let incr_val = size_type.const_int(1, false);
gen_for_callback_incrementing(
ctx,
generator,
init_val,
max_val,
|generator, ctx, _hooks, idx| {
// Get the dimension at `idx`
let dim =
shape_list.data().get(ctx, generator, &idx, None).into_int_value();
// Cast `dim` to SizeT
let dim = ctx
.builder
.build_int_s_extend_or_bit_cast(dim, size_type, "dim_casted")
.unwrap();
// Write
dst_slice.set_typed(ctx, generator, &idx, dim);
Ok(())
},
incr_val,
)?;
Ok(())
}),
})
}
TypeEnum::TTuple { ty: tuple_types } => {
// 2. A tuple of ints; e.g., `np.empty((600, 800, 3))`
// Get the length/size of the tuple, which also happens to be the value of `ndims`.
let ndims = tuple_types.len();
// A tuple has to be a StructValue
// Read [`codegen::expr::gen_expr`] to see how `nac3core` translates a Python tuple into LLVM.
let shape_tuple = shape.into_struct_value();
Ok(Producer {
count: size_type.const_int(ndims as u64, false),
write_to_slice: Box::new(move |generator, ctx, dst_slice| {
for dim_i in 0..ndims {
// Get the dimension at `dim_i`
let dim = ctx
.builder
.build_extract_value(
shape_tuple,
dim_i as u32,
format!("dim{dim_i}").as_str(),
)
.unwrap()
.into_int_value();
// Cast `dim` to SizeT
let dim = ctx
.builder
.build_int_s_extend_or_bit_cast(dim, size_type, "dim_casted")
.unwrap();
// Write
dst_slice.set_typed(
ctx,
generator,
&size_type.const_int(dim_i as u64, false),
dim,
);
}
Ok(())
}),
})
}
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() =>
{
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
// The value has to be an integer
let shape_int = shape.into_int_value();
Ok(Producer {
count: size_type.const_int(1, false),
write_to_slice: Box::new(move |generator, ctx, dst_slice| {
// Only index 0 is set with the input value
let dim_i = size_type.const_zero();
// Cast `shape_int` to SizeT
let dim = ctx
.builder
.build_int_s_extend_or_bit_cast(shape_int, size_type, "dim_casted")
.unwrap();
// Write
dst_slice.set_typed(ctx, generator, &dim_i, dim);
Ok(())
}),
})
}
_ => panic!("parse_input_shape_arg encountered unknown type"),
}
}
/// TODO: DOCUMENT ME
fn alloca_ndarray_uninitialized_shaped<'ctx, G>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_type: BasicTypeEnum<'ctx>,
shape_producer: Producer<'ctx, G, IntValue<'ctx>>,
name: Option<&str>,
) -> Result<NpArrayValue<'ctx>, String>
where
G: CodeGenerator + ?Sized,
{
// Allocate an uninitialized ndarray
let ndims = shape_producer.count;
let ndarray = alloca_ndarray_uninitialized(generator, ctx, elem_type, ndims, name)?;
// Fill `ndarray.shape` with `shape_producer`
(shape_producer.write_to_slice)(generator, ctx, &ndarray.shape_slice(ctx))?;
Ok(ndarray)
}
/// LLVM-typed implementation for generating the implementation for constructing an empty `NDArray`.
fn call_ndarray_empty_impl<'ctx, G>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
shape: BasicValueEnum<'ctx>,
shape_ty: Type,
name: Option<&str>,
) -> Result<NpArrayValue<'ctx>, String>
where
G: CodeGenerator + ?Sized,
{
let elem_type = ctx.get_llvm_type(generator, elem_ty);
let shape_producer = parse_input_shape_arg(generator, ctx, shape, shape_ty)?;
alloca_ndarray_uninitialized_shaped(generator, ctx, elem_type, shape_producer, name)
}
/// Generates LLVM IR for `np.empty`.
pub fn gen_ndarray_empty<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let shape_ty = fun.0.args[0].ty;
let shape = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
let ndarray = call_ndarray_empty_impl(
generator,
context,
context.primitives.float,
shape,
shape_ty,
None,
)?;
Ok(ndarray.ptr)
}
/// Generates LLVM IR for `np.zeros`.
///
/// NOTE: Current `dtype` is always `float64`.
pub fn gen_ndarray_zeros<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
obj: &Option<(Type, ValueEnum<'ctx>)>,
fun: (&FunSignature, DefinitionId),
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
generator: &mut dyn CodeGenerator,
) -> Result<PointerValue<'ctx>, String> {
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let shape_ty = fun.0.args[0].ty;
let shape = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
// Allocate an ndarray and fill it later
let ndarray = call_ndarray_empty_impl(
generator,
context,
context.primitives.float, // float64
shape,
shape_ty,
None,
)?;
// TRICK: The float64 type could be conveniently extracted out of `ndarray`
let float_type = ndarray.ty.elem_type.into_float_type();
// Fill the ndarray
call_nac3_ndarray_fill_generic(context, ndarray, float_type.const_float(1.0).into());
// Return our ndarray
println!("ndarray.ptr = {}", ndarray.ptr);
Ok(ndarray.ptr)
}

View File

@ -23,4 +23,4 @@ pub mod codegen;
pub mod symbol_resolver;
pub mod toplevel;
pub mod typecheck;
pub mod util;
pub mod util;

View File

@ -1193,13 +1193,14 @@ impl<'a> BuiltinBuilder<'a> {
self.ndarray_float,
&[(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
Box::new(move |ctx, obj, fun, args, generator| {
let func = match prim {
PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => gen_ndarray_empty,
PrimDef::FunNpZeros => gen_ndarray_zeros,
PrimDef::FunNpOnes => todo!(), // gen_ndarray_ones,
_ => unreachable!(),
};
func(ctx, &obj, fun, &args, generator).map(|val| Some(val.as_basic_value_enum()))
todo!()
// let func = match prim {
// PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => gen_ndarray_empty,
// PrimDef::FunNpZeros => gen_ndarray_zeros,
// PrimDef::FunNpOnes => gen_ndarray_ones,
// _ => unreachable!(),
// };
// func(ctx, &obj, fun, &args, generator).map(|val| Some(val.as_basic_value_enum()))
}),
)
}

View File

@ -1,5 +0,0 @@
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SizeVariant {
Bits32,
Bits64,
}

View File

@ -1,3 +0,0 @@
def run() -> int32:
hello = np_zeros((3, 4))
return 0

View File

@ -449,9 +449,6 @@ fn main() {
.create_target_machine(llvm_options.opt_level)
.expect("couldn't create target machine");
// NOTE: DEBUG PRINT
main.print_to_file("standalone.ll").unwrap();
let pass_options = PassBuilderOptions::create();
pass_options.set_merge_functions(true);
let passes = format!("default<O{}>", opt_level as u32);