From e719d9396d4f6fb46651a4e7ad198871dc682c9f Mon Sep 17 00:00:00 2001 From: lyken Date: Thu, 11 Jul 2024 00:34:06 +0800 Subject: [PATCH] asd --- flake.nix | 3 + nac3core/irrt/irrt_numpy_ndarray.hpp | 51 +++- nac3core/irrt/irrt_test.cpp | 34 +-- nac3core/src/codegen/classes.rs | 378 ++++++++++++++++++++++++-- nac3core/src/codegen/irrt/mod.rs | 194 ++++++++++--- nac3core/src/codegen/mod.rs | 7 +- nac3core/src/codegen/numpy.rs | 344 ++++++++++++++++++++++- nac3core/src/lib.rs | 1 + nac3core/src/toplevel/builtins.rs | 41 ++- nac3core/src/util.rs | 5 + nac3standalone/demo/src/my_ndarray.py | 3 + nac3standalone/src/main.rs | 3 + 12 files changed, 946 insertions(+), 118 deletions(-) create mode 100644 nac3core/src/util.rs create mode 100644 nac3standalone/demo/src/my_ndarray.py diff --git a/flake.nix b/flake.nix index a6ce5fce..79b1fa63 100644 --- a/flake.nix +++ b/flake.nix @@ -163,7 +163,10 @@ clippy pre-commit rustfmt + rust-analyzer ]; + # https://nixos.wiki/wiki/Rust#Shell.nix_example + RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}"; }; devShells.x86_64-linux.msys2 = pkgs.mkShell { name = "nac3-dev-shell-msys2"; diff --git a/nac3core/irrt/irrt_numpy_ndarray.hpp b/nac3core/irrt/irrt_numpy_ndarray.hpp index 8e1f1d50..3caf8531 100644 --- a/nac3core/irrt/irrt_numpy_ndarray.hpp +++ b/nac3core/irrt/irrt_numpy_ndarray.hpp @@ -212,11 +212,11 @@ namespace { return this->size() * itemsize; } - void set_value_at_pelement(uint8_t* pelement, const uint8_t* pvalue) { + void set_pelement_value(uint8_t* pelement, const uint8_t* pvalue) { __builtin_memcpy(pelement, pvalue, itemsize); } - uint8_t* get_pelement(const SizeT *indices) { + uint8_t* get_pelement_by_indices(const SizeT *indices) { uint8_t* element = data; for (SizeT dim_i = 0; dim_i < ndims; dim_i++) element += indices[dim_i] * strides[dim_i]; @@ -229,7 +229,7 @@ namespace { SizeT* indices = (SizeT*) __builtin_alloca(sizeof(SizeT) * this->ndims); ndarray_util::set_indices_by_nth(this->ndims, this->shape, indices, nth); - return get_pelement(indices); + return get_pelement_by_indices(indices); } // Get pointer to the first element of this ndarray, assuming @@ -259,8 +259,8 @@ namespace { iter.set_indices_zero(); for (SizeT i = 0; i < this->size(); i++, iter.next()) { - uint8_t* pelement = get_pelement(iter.indices); - set_value_at_pelement(pelement, pvalue); + uint8_t* pelement = get_pelement_by_indices(iter.indices); + set_pelement_value(pelement, pvalue); } } @@ -283,8 +283,8 @@ namespace { if (!in_bounds(indices)) continue; - uint8_t* pelement = get_pelement(indices); - set_value_at_pelement(pelement, one_pvalue); + uint8_t* pelement = get_pelement_by_indices(indices); + set_pelement_value(pelement, one_pvalue); } } @@ -403,6 +403,43 @@ namespace { } } } + + // Simulates `this_ndarray[:] = src_ndarray`, with automatic broadcasting. + // Caution on https://github.com/numpy/numpy/issues/21744 + // Also see `NDArray::broadcast_to` + void assign_with(NDArray* src_ndarray) { + irrt_assert( + ndarray_util::can_broadcast_shape_to( + this->ndims, + this->shape, + src_ndarray->ndims, + src_ndarray->shape + ) + ); + + // Broadcast the `src_ndarray` to make the reading process *much* easier + SizeT* broadcasted_src_ndarray_strides = __builtin_alloca(sizeof(SizeT) * this->ndims); // Remember to allocate strides beforehand + NDArray broadcasted_src_ndarray = { + .ndims = this->ndims, + .shape = this->shape, + .strides = broadcasted_src_ndarray_strides + }; + src_ndarray->broadcast_to(&broadcasted_src_ndarray); + + // Using iter instead of `get_nth_pelement` because it is slightly faster + SizeT* indices = __builtin_alloca(sizeof(SizeT) * this->ndims); + auto iter = NDArrayIndicesIter { + .ndims = this->ndims, + .shape = this->shape, + .indices = indices + }; + const SizeT this_size = this->size(); + for (SizeT i = 0; i < this_size; i++, iter.next()) { + uint8_t* src_pelement = broadcasted_src_ndarray_strides->get_pelement_by_indices(indices); + uint8_t* this_pelement = this->get_pelement_by_indices(indices); + this->set_pelement_value(src_pelement, src_pelement); + } + } }; } diff --git a/nac3core/irrt/irrt_test.cpp b/nac3core/irrt/irrt_test.cpp index f6e67ff3..7142c865 100644 --- a/nac3core/irrt/irrt_test.cpp +++ b/nac3core/irrt/irrt_test.cpp @@ -81,7 +81,7 @@ void __print_ndarray_aux(const char *format, bool first, bool last, SizeT* curso SizeT* indices = (SizeT*) __builtin_alloca(sizeof(SizeT) * ndarray->ndims); for (SizeT i = 0; i < dim; i++) { ndarray_util::set_indices_by_nth(ndarray->ndims, ndarray->shape, indices, *cursor); - ElementT* pelement = (ElementT*) ndarray->get_pelement(indices); + ElementT* pelement = (ElementT*) ndarray->get_pelement_by_indices(indices); ElementT element = *pelement; if (i != 0) printf(", "); // List delimiter @@ -165,7 +165,7 @@ void test_ndarray_indices_iter_normal() { int32_t shape[3] = { 1, 2, 3 }; int32_t indices[3] = { 0, 0, 0 }; auto iter = NDArrayIndicesIter { - .ndims = 3u, + .ndims = 3, .shape = shape, .indices = indices }; @@ -394,10 +394,10 @@ void test_ndslice_1() { assert_arrays_match("shape", "%d", dst_ndims, expected_shape, dst_ndarray.shape); assert_arrays_match("strides", "%d", dst_ndims, expected_strides, dst_ndarray.strides); - assert_values_match("dst_ndarray[0, 0]", "%f", 5.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 0, 0 }))); - assert_values_match("dst_ndarray[0, 1]", "%f", 7.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 0, 1 }))); - assert_values_match("dst_ndarray[1, 0]", "%f", 9.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 1, 0 }))); - assert_values_match("dst_ndarray[1, 1]", "%f", 11.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 1, 1 }))); + assert_values_match("dst_ndarray[0, 0]", "%f", 5.0, *((double *) dst_ndarray.get_pelement_by_indices((int32_t[dst_ndims]) { 0, 0 }))); + assert_values_match("dst_ndarray[0, 1]", "%f", 7.0, *((double *) dst_ndarray.get_pelement_by_indices((int32_t[dst_ndims]) { 0, 1 }))); + assert_values_match("dst_ndarray[1, 0]", "%f", 9.0, *((double *) dst_ndarray.get_pelement_by_indices((int32_t[dst_ndims]) { 1, 0 }))); + assert_values_match("dst_ndarray[1, 1]", "%f", 11.0, *((double *) dst_ndarray.get_pelement_by_indices((int32_t[dst_ndims]) { 1, 1 }))); } void test_ndslice_2() { @@ -471,8 +471,8 @@ void test_ndslice_2() { assert_arrays_match("strides", "%d", dst_ndims, expected_strides, dst_ndarray.strides); // [5.0, 3.0] - assert_values_match("dst_ndarray[0]", "%f", 11.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 0 }))); - assert_values_match("dst_ndarray[1]", "%f", 9.0, *((double *) dst_ndarray.get_pelement((int32_t[dst_ndims]) { 1 }))); + assert_values_match("dst_ndarray[0]", "%f", 11.0, *((double *) dst_ndarray.get_pelement_by_indices((int32_t[dst_ndims]) { 0 }))); + assert_values_match("dst_ndarray[1]", "%f", 9.0, *((double *) dst_ndarray.get_pelement_by_indices((int32_t[dst_ndims]) { 1 }))); } void test_can_broadcast_shape() { @@ -618,15 +618,15 @@ void test_ndarray_broadcast_1() { assert_arrays_match("dst_ndarray->strides", "%d", dst_ndims, (int32_t[]) { 0, 0, 8 }, dst_ndarray.strides); - assert_values_match("dst_ndarray[0, 0, 0]", "%f", 19.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 0, 0}))); - assert_values_match("dst_ndarray[0, 0, 1]", "%f", 29.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 0, 1}))); - assert_values_match("dst_ndarray[0, 0, 2]", "%f", 39.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 0, 2}))); - assert_values_match("dst_ndarray[0, 0, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 0, 3}))); - assert_values_match("dst_ndarray[0, 1, 0]", "%f", 19.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 1, 0}))); - assert_values_match("dst_ndarray[0, 1, 1]", "%f", 29.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 1, 1}))); - assert_values_match("dst_ndarray[0, 1, 2]", "%f", 39.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 1, 2}))); - assert_values_match("dst_ndarray[0, 1, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {0, 1, 3}))); - assert_values_match("dst_ndarray[1, 2, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement((int32_t[]) {1, 2, 3}))); + assert_values_match("dst_ndarray[0, 0, 0]", "%f", 19.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 0, 0}))); + assert_values_match("dst_ndarray[0, 0, 1]", "%f", 29.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 0, 1}))); + assert_values_match("dst_ndarray[0, 0, 2]", "%f", 39.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 0, 2}))); + assert_values_match("dst_ndarray[0, 0, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 0, 3}))); + assert_values_match("dst_ndarray[0, 1, 0]", "%f", 19.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 1, 0}))); + assert_values_match("dst_ndarray[0, 1, 1]", "%f", 29.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 1, 1}))); + assert_values_match("dst_ndarray[0, 1, 2]", "%f", 39.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 1, 2}))); + assert_values_match("dst_ndarray[0, 1, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {0, 1, 3}))); + assert_values_match("dst_ndarray[1, 2, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {1, 2, 3}))); } int main() { diff --git a/nac3core/src/codegen/classes.rs b/nac3core/src/codegen/classes.rs index a173035d..032ba0b4 100644 --- a/nac3core/src/codegen/classes.rs +++ b/nac3core/src/codegen/classes.rs @@ -2,7 +2,8 @@ use crate::codegen::{ // irrt::{call_ndarray_calc_size, call_ndarray_flatten_index}, llvm_intrinsics::call_int_umin, stmt::gen_for_callback_incrementing, - CodeGenContext, CodeGenerator, + CodeGenContext, + CodeGenerator, }; use inkwell::context::Context; use inkwell::types::{ArrayType, BasicType, StructType}; @@ -12,6 +13,7 @@ use inkwell::{ values::{BasicValueEnum, IntValue, PointerValue}, AddressSpace, IntPredicate, }; +use itertools::Itertools; /// A LLVM type that is used to represent a non-primitive type in NAC3. pub trait ProxyType<'ctx>: Into { @@ -1208,25 +1210,27 @@ impl<'ctx> NDArrayType<'ctx> { ctx: &'ctx Context, dtype: BasicTypeEnum<'ctx>, ) -> Self { - let llvm_usize = generator.get_size_type(ctx); + todo!() - // struct NDArray { num_dims: size_t, dims: size_t*, data: T* } - // - // * num_dims: Number of dimensions in the array - // * dims: Pointer to an array containing the size of each dimension - // * data: Pointer to an array containing the array data - let llvm_ndarray = ctx - .struct_type( - &[ - llvm_usize.into(), - llvm_usize.ptr_type(AddressSpace::default()).into(), - dtype.ptr_type(AddressSpace::default()).into(), - ], - false, - ) - .ptr_type(AddressSpace::default()); + // let llvm_usize = generator.get_size_type(ctx); - NDArrayType::from_type(llvm_ndarray, llvm_usize) + // // struct NDArray { num_dims: size_t, dims: size_t*, data: T* } + // // + // // * num_dims: Number of dimensions in the array + // // * dims: Pointer to an array containing the size of each dimension + // // * data: Pointer to an array containing the array data + // let llvm_ndarray = ctx + // .struct_type( + // &[ + // llvm_usize.into(), + // llvm_usize.ptr_type(AddressSpace::default()).into(), + // dtype.ptr_type(AddressSpace::default()).into(), + // ], + // false, + // ) + // .ptr_type(AddressSpace::default()); + + // NDArrayType::from_type(llvm_ndarray, llvm_usize) } /// Creates an [`NDArrayType`] from a [`PointerType`]. @@ -1763,3 +1767,341 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeMutator<'ctx, for NDArrayDataProxy<'ctx, '_> { } + +#[derive(Debug, Clone, Copy)] +pub struct StructField<'ctx> { + /// The GEP index of this struct field. + pub gep_index: u32, + /// Name of this struct field. + /// + /// Used for generating names. + pub name: &'static str, + /// The type of this struct field. + pub ty: BasicTypeEnum<'ctx>, +} + +pub struct StructFields<'ctx> { + /// Name of the struct. + /// + /// Used for generating names. + pub name: &'static str, + + /// All the [`StructField`]s of this struct. + /// + /// **NOTE:** The index position of a [`StructField`] + /// matches the element's [`StructField::index`]. + pub fields: Vec>, +} + +struct StructFieldsBuilder<'ctx> { + gep_index_counter: u32, + /// Name of the struct to be built. + name: &'static str, + fields: Vec>, +} + +impl<'ctx> StructField<'ctx> { + /// TODO: DOCUMENT ME + pub fn gep( + &self, + ctx: &CodeGenContext<'ctx, '_>, + struct_ptr: PointerValue<'ctx>, + ) -> PointerValue<'ctx> { + let index_type = ctx.ctx.i32_type(); // TODO: I think I'm not supposed to use i32 for GEP like that + unsafe { + ctx.builder + .build_in_bounds_gep( + struct_ptr, + &[index_type.const_zero(), index_type.const_int(self.gep_index as u64, false)], + self.name, + ) + .unwrap() + } + } + + /// TODO: DOCUMENT ME + pub fn load( + &self, + ctx: &CodeGenContext<'ctx, '_>, + struct_ptr: PointerValue<'ctx>, + ) -> BasicValueEnum<'ctx> { + ctx.builder.build_load(self.gep(ctx, struct_ptr), self.name).unwrap() + } + + /// TODO: DOCUMENT ME + pub fn store(&self, ctx: &CodeGenContext<'ctx, '_>, struct_ptr: PointerValue<'ctx>, value: V) + where + V: BasicValue<'ctx>, + { + ctx.builder.build_store(self.gep(ctx, struct_ptr), value).unwrap(); + } +} + +type IsInstanceError = String; +type IsInstanceResult = Result<(), IsInstanceError>; + +pub fn check_basic_types_match<'ctx, A, B>(expected: A, got: B) -> IsInstanceResult +where + A: BasicType<'ctx>, + B: BasicType<'ctx>, +{ + let expected = expected.as_basic_type_enum(); + let got = got.as_basic_type_enum(); + + // Put those logic into here, + // otherwise there is always a fallback reporting on any kind of mismatch + match (expected, got) { + (BasicTypeEnum::IntType(expected), BasicTypeEnum::IntType(got)) => { + if expected.get_bit_width() != got.get_bit_width() { + return Err(format!( + "Expected IntType ({expected}-bit(s)), got IntType ({got}-bit(s))" + )); + } + } + (expected, got) => { + if expected != got { + return Err(format!("Expected {expected}, got {got}")); + } + } + } + Ok(()) +} + +impl<'ctx> StructFields<'ctx> { + pub fn num_fields(&self) -> u32 { + self.fields.len() as u32 + } + + pub fn get_struct_type(&self, ctx: &'ctx Context) -> StructType<'ctx> { + let llvm_fields = self.fields.iter().map(|field| field.ty).collect_vec(); + ctx.struct_type(llvm_fields.as_slice(), false) + } + + pub fn is_type(&self, scrutinee: StructType<'ctx>) -> IsInstanceResult { + // Check scrutinee's number of struct fields + if scrutinee.count_fields() != self.num_fields() { + return Err(format!( + "Expected {expected_count} field(s) in `{struct_name}` type, got {got_count}", + struct_name = self.name, + expected_count = self.num_fields(), + got_count = scrutinee.count_fields(), + )); + } + + // Check the scrutinee's field types + for field in self.fields.iter() { + let expected_field_ty = field.ty; + let got_field_ty = scrutinee.get_field_type_at_index(field.gep_index).unwrap(); + + if let Err(field_err) = check_basic_types_match(expected_field_ty, got_field_ty) { + return Err(format!( + "Field GEP index {gep_index} does not match the expected type of ({struct_name}::{field_name}): {field_err}", + gep_index = field.gep_index, + struct_name = self.name, + field_name = field.name, + )); + } + } + + // Done + Ok(()) + } +} + +impl<'ctx> StructFieldsBuilder<'ctx> { + fn start(name: &'static str) -> Self { + StructFieldsBuilder { gep_index_counter: 0, name, fields: Vec::new() } + } + + fn add_field(&mut self, name: &'static str, ty: BasicTypeEnum<'ctx>) -> StructField<'ctx> { + let index = self.gep_index_counter; + self.gep_index_counter += 1; + + let field = StructField { gep_index: index, name, ty }; + self.fields.push(field); // Register into self.fields + + field // Return to the caller to conveniently let them do whatever they want + } + + fn end(self) -> StructFields<'ctx> { + StructFields { name: self.name, fields: self.fields } + } +} + +#[derive(Debug, Clone, Copy)] +pub struct NpArrayType<'ctx> { + pub size_type: IntType<'ctx>, + pub elem_type: BasicTypeEnum<'ctx>, +} + +pub struct NpArrayStructFields<'ctx> { + pub whole_struct: StructFields<'ctx>, + pub data: StructField<'ctx>, + pub itemsize: StructField<'ctx>, + pub ndims: StructField<'ctx>, + pub shape: StructField<'ctx>, + pub strides: StructField<'ctx>, +} + +impl<'ctx> NpArrayType<'ctx> { + pub fn new_opaque_elem( + ctx: &CodeGenContext<'ctx, '_>, + size_type: IntType<'ctx>, + ) -> NpArrayType<'ctx> { + NpArrayType { size_type, elem_type: ctx.ctx.i8_type().as_basic_type_enum() } + } + + pub fn get_struct_type(&self, ctx: &'ctx Context) -> StructType<'ctx> { + self.fields().whole_struct.get_struct_type(ctx) + } + + pub fn fields(&self) -> NpArrayStructFields<'ctx> { + let mut builder = StructFieldsBuilder::start("NpArray"); + + let addrspace = AddressSpace::default(); + + let byte_type = self.size_type.get_context().i8_type(); + + // Make sure the struct matches PERFECTLY with that defined in `nac3core/irrt`. + let data = builder.add_field("data", byte_type.ptr_type(addrspace).into()); + let itemsize = builder.add_field("itemsize", self.size_type.into()); + let ndims = builder.add_field("ndims", self.size_type.into()); + let shape = builder.add_field("shape", self.size_type.ptr_type(addrspace).into()); + let strides = builder.add_field("strides", self.size_type.ptr_type(addrspace).into()); + + NpArrayStructFields { whole_struct: builder.end(), data, itemsize, ndims, shape, strides } + } + + /// Allocate an `ndarray` on stack, with the following notes: + /// + /// - `ndarray.ndims` will be initialized to `in_ndims`. + /// - `ndarray.itemsize` will be initialized to the size of `self.elem_type.size_of()`. + /// - `ndarray.shape` and `ndarray.strides` will be allocated on the stack with number of elements being `in_ndims`, + /// all with empty/uninitialized values. + pub fn var_alloc( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + in_ndims: IntValue<'ctx>, + name: Option<&str>, + ) -> NpArrayValue<'ctx> + where + G: CodeGenerator + ?Sized, + { + let ptr = generator + .gen_var_alloc(ctx, self.get_struct_type(ctx.ctx).as_basic_type_enum(), name) + .unwrap(); + + // Allocate `in_dims` number of `size_type` on the stack for `shape` and `strides` + let allocated_shape = generator + .gen_array_var_alloc( + ctx, + self.size_type.as_basic_type_enum(), + in_ndims, + Some("allocated_shape"), + ) + .unwrap(); + let allocated_strides = generator + .gen_array_var_alloc( + ctx, + self.size_type.as_basic_type_enum(), + in_ndims, + Some("allocated_strides"), + ) + .unwrap(); + + let value = NpArrayValue { ty: *self, ptr }; + value.store_ndims(ctx, in_ndims); + value.store_itemsize(ctx, self.elem_type.size_of().unwrap()); + value.store_shape(ctx, allocated_shape.base_ptr(ctx, generator)); + value.store_strides(ctx, allocated_strides.base_ptr(ctx, generator)); + + return value; + } +} + +#[derive(Debug, Clone, Copy)] +pub struct NpArrayValue<'ctx> { + pub ty: NpArrayType<'ctx>, + pub ptr: PointerValue<'ctx>, +} + +impl<'ctx> NpArrayValue<'ctx> { + pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { + let field = self.ty.fields().ndims; + field.load(ctx, self.ptr).into_int_value() + } + + pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) { + let field = self.ty.fields().ndims; + field.store(ctx, self.ptr, value); + } + + pub fn load_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { + let field = self.ty.fields().itemsize; + field.load(ctx, self.ptr).into_int_value() + } + + pub fn store_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>, value: IntValue<'ctx>) { + let field = self.ty.fields().itemsize; + field.store(ctx, self.ptr, value); + } + + pub fn load_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let field = self.ty.fields().shape; + field.load(ctx, self.ptr).into_pointer_value() + } + + pub fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) { + let field = self.ty.fields().shape; + field.store(ctx, self.ptr, value); + } + + pub fn load_strides(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let field = self.ty.fields().strides; + field.load(ctx, self.ptr).into_pointer_value() + } + + pub fn store_strides(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) { + let field = self.ty.fields().strides; + field.store(ctx, self.ptr, value); + } + + /// TODO: DOCUMENT ME -- NDIMS WOULD NEVER CHANGE!!!!! + pub fn shape_slice( + &self, + ctx: &CodeGenContext<'ctx, '_>, + ) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { + // Get the pointer to `shape` + let field = self.ty.fields().shape; + let shape = field.load(ctx, self.ptr).into_pointer_value(); + + // Load `ndims` + let ndims = self.load_ndims(ctx); + + TypedArrayLikeAdapter { + adapted: ArraySliceValue(shape, ndims, Some(field.name)), + downcast_fn: Box::new(|_ctx, x| x.into_int_value()), + upcast_fn: Box::new(|_ctx, x| x.as_basic_value_enum()), + } + } + + /// TODO: DOCUMENT ME -- NDIMS WOULD NEVER CHANGE!!!!! + pub fn strides_slice( + &self, + ctx: &CodeGenContext<'ctx, '_>, + ) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { + // Get the pointer to `strides` + let field = self.ty.fields().strides; + let strides = field.load(ctx, self.ptr).into_pointer_value(); + + // Load `ndims` + let ndims = self.load_ndims(ctx); + + TypedArrayLikeAdapter { + adapted: ArraySliceValue(strides, ndims, Some(field.name)), + downcast_fn: Box::new(|_ctx, x| x.into_int_value()), + upcast_fn: Box::new(|_ctx, x| x.as_basic_value_enum()), + } + } +} diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index dd9edbfc..284d9470 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -1,11 +1,11 @@ -use crate::typecheck::typedef::Type; +use crate::{typecheck::typedef::Type, util::SizeVariant}; mod test; use super::{ classes::{ - ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue, - TypedArrayLikeAdapter, UntypedArrayLikeAccessor, + check_basic_types_match, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, + NDArrayValue, NpArrayType, NpArrayValue, TypedArrayLikeAdapter, UntypedArrayLikeAccessor, }, llvm_intrinsics, CodeGenContext, CodeGenerator, }; @@ -16,8 +16,8 @@ use inkwell::{ context::Context, memory_buffer::MemoryBuffer, module::Module, - types::{BasicTypeEnum, IntType}, - values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue}, + types::{BasicType, BasicTypeEnum, FunctionType, IntType, PointerType}, + values::{BasicValue, BasicValueEnum, CallSiteValue, FloatValue, FunctionValue, IntValue}, AddressSpace, IntPredicate, }; use itertools::Either; @@ -583,7 +583,7 @@ pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> Flo // { // let llvm_usize = generator.get_size_type(ctx.ctx); // let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); -// +// // let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() { // 32 => "__nac3_ndarray_calc_size", // 64 => "__nac3_ndarray_calc_size64", @@ -597,7 +597,7 @@ pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> Flo // ctx.module.get_function(ndarray_calc_size_fn_name).unwrap_or_else(|| { // ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None) // }); -// +// // let begin = begin.unwrap_or_else(|| llvm_usize.const_zero()); // let end = end.unwrap_or_else(|| dims.size(ctx, generator)); // ctx.builder @@ -616,7 +616,7 @@ pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> Flo // .map(Either::unwrap_left) // .unwrap() // } -// +// // /// Generates a call to `__nac3_ndarray_calc_nd_indices`. Returns a [`TypeArrayLikeAdpater`] // /// containing `i32` indices of the flattened index. // /// @@ -634,7 +634,7 @@ pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> Flo // let llvm_usize = generator.get_size_type(ctx.ctx); // let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); // let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); -// +// // let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() { // 32 => "__nac3_ndarray_calc_nd_indices", // 64 => "__nac3_ndarray_calc_nd_indices64", @@ -646,15 +646,15 @@ pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> Flo // &[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()], // false, // ); -// +// // ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None) // }); -// +// // let ndarray_num_dims = ndarray.load_ndims(ctx); // let ndarray_dims = ndarray.dim_sizes(); -// +// // let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap(); -// +// // ctx.builder // .build_call( // ndarray_calc_nd_indices_fn, @@ -667,14 +667,14 @@ pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> Flo // "", // ) // .unwrap(); -// +// // TypedArrayLikeAdapter::from( // ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None), // Box::new(|_, v| v.into_int_value()), // Box::new(|_, v| v.into()), // ) // } -// +// // fn call_ndarray_flatten_index_impl<'ctx, G, Indices>( // generator: &G, // ctx: &CodeGenContext<'ctx, '_>, @@ -687,10 +687,10 @@ pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> Flo // { // let llvm_i32 = ctx.ctx.i32_type(); // let llvm_usize = generator.get_size_type(ctx.ctx); -// +// // let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); // let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); -// +// // debug_assert_eq!( // IntType::try_from(indices.element_type(ctx, generator)) // .map(IntType::get_bit_width) @@ -703,7 +703,7 @@ pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> Flo // llvm_usize.get_bit_width(), // "Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`" // ); -// +// // let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() { // 32 => "__nac3_ndarray_flatten_index", // 64 => "__nac3_ndarray_flatten_index64", @@ -715,13 +715,13 @@ pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> Flo // &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()], // false, // ); -// +// // ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None) // }); -// +// // let ndarray_num_dims = ndarray.load_ndims(ctx); // let ndarray_dims = ndarray.dim_sizes(); -// +// // let index = ctx // .builder // .build_call( @@ -738,10 +738,10 @@ pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> Flo // .map(|v| v.map_left(BasicValueEnum::into_int_value)) // .map(Either::unwrap_left) // .unwrap(); -// +// // index // } -// +// // /// Generates a call to `__nac3_ndarray_flatten_index`. Returns the flattened index for the // /// multidimensional index. // /// @@ -760,7 +760,7 @@ pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> Flo // { // call_ndarray_flatten_index_impl(generator, ctx, ndarray, indices) // } -// +// // /// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of // /// dimension and size of each dimension of the resultant `ndarray`. // pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>( @@ -771,7 +771,7 @@ pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> Flo // ) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { // let llvm_usize = generator.get_size_type(ctx.ctx); // let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); -// +// // let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { // 32 => "__nac3_ndarray_calc_broadcast", // 64 => "__nac3_ndarray_calc_broadcast64", @@ -789,14 +789,14 @@ pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> Flo // ], // false, // ); -// +// // ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) // }); -// +// // let lhs_ndims = lhs.load_ndims(ctx); // let rhs_ndims = rhs.load_ndims(ctx); // let min_ndims = llvm_intrinsics::call_int_umin(ctx, lhs_ndims, rhs_ndims, None); -// +// // gen_for_callback_incrementing( // generator, // ctx, @@ -810,7 +810,7 @@ pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> Flo // rhs.dim_sizes().get_typed_unchecked(ctx, generator, &idx, None), // ) // }; -// +// // let llvm_usize_const_one = llvm_usize.const_int(1, false); // let lhs_eqz = ctx // .builder @@ -821,14 +821,14 @@ pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> Flo // .build_int_compare(IntPredicate::EQ, rhs_dim_sz, llvm_usize_const_one, "") // .unwrap(); // let lhs_or_rhs_eqz = ctx.builder.build_or(lhs_eqz, rhs_eqz, "").unwrap(); -// +// // let lhs_eq_rhs = ctx // .builder // .build_int_compare(IntPredicate::EQ, lhs_dim_sz, rhs_dim_sz, "") // .unwrap(); -// +// // let is_compatible = ctx.builder.build_or(lhs_or_rhs_eqz, lhs_eq_rhs, "").unwrap(); -// +// // ctx.make_assert( // generator, // is_compatible, @@ -837,13 +837,13 @@ pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> Flo // [None, None, None], // ctx.current_loc, // ); -// +// // Ok(()) // }, // llvm_usize.const_int(1, false), // ) // .unwrap(); -// +// // let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None); // let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator); // let lhs_ndims = lhs.load_ndims(ctx); @@ -851,7 +851,7 @@ pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> Flo // let rhs_ndims = rhs.load_ndims(ctx); // let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap(); // let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None); -// +// // ctx.builder // .build_call( // ndarray_calc_broadcast_fn, @@ -865,14 +865,14 @@ pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> Flo // "", // ) // .unwrap(); -// +// // TypedArrayLikeAdapter::from( // out_dims, // Box::new(|_, v| v.into_int_value()), // Box::new(|_, v| v.into()), // ) // } -// +// // /// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`] // /// containing the indices used for accessing `array` corresponding to the index of the broadcasted // /// array `broadcast_idx`. @@ -890,7 +890,7 @@ pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> Flo // let llvm_usize = generator.get_size_type(ctx.ctx); // let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); // let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); -// +// // let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { // 32 => "__nac3_ndarray_calc_broadcast_idx", // 64 => "__nac3_ndarray_calc_broadcast_idx64", @@ -902,19 +902,19 @@ pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> Flo // &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()], // false, // ); -// +// // ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) // }); -// +// // let broadcast_size = broadcast_idx.size(ctx, generator); // let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap(); -// +// // let array_dims = array.dim_sizes().base_ptr(ctx, generator); // let array_ndims = array.load_ndims(ctx); // let broadcast_idx_ptr = unsafe { // broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None) // }; -// +// // ctx.builder // .build_call( // ndarray_calc_broadcast_fn, @@ -922,10 +922,118 @@ pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> Flo // "", // ) // .unwrap(); -// +// // TypedArrayLikeAdapter::from( // ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None), // Box::new(|_, v| v.into_int_value()), // Box::new(|_, v| v.into()), // ) // } + +fn get_size_variant<'ctx>(ty: IntType<'ctx>) -> SizeVariant { + match ty.get_bit_width() { + 32 => SizeVariant::Bits32, + 64 => SizeVariant::Bits64, + _ => unreachable!("Unsupported int type bit width {}", ty.get_bit_width()), + } +} + +fn get_size_type_dependent_function<'ctx, BuildFuncTypeFn>( + ctx: &CodeGenContext<'ctx, '_>, + size_type: IntType<'ctx>, + base_name: &str, + build_func_type: BuildFuncTypeFn, +) -> FunctionValue<'ctx> +where + BuildFuncTypeFn: Fn() -> FunctionType<'ctx>, +{ + let mut fn_name = base_name.to_owned(); + match get_size_variant(size_type) { + SizeVariant::Bits32 => { + // The original fn_name is the correct function name + } + SizeVariant::Bits64 => { + // Append "64" at the end, this is the naming convention for 64-bit + fn_name.push_str("64"); + } + } + + // Get (or declare then get if does not exist) the corresponding function + ctx.module.get_function(&fn_name).unwrap_or_else(|| { + let fn_type = build_func_type(); + ctx.module.add_function(&fn_name, fn_type, None) + }) +} + +fn get_irrt_ndarray_ptr_type<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + size_type: IntType<'ctx>, +) -> PointerType<'ctx> { + let i8_type = ctx.ctx.i8_type(); + + let ndarray_ty = NpArrayType { size_type, elem_type: i8_type.as_basic_type_enum() }; + let struct_ty = ndarray_ty.get_struct_type(ctx.ctx); + struct_ty.ptr_type(AddressSpace::default()) +} + +fn get_irrt_opaque_uint8_ptr_type<'ctx>(ctx: &CodeGenContext<'ctx, '_>) -> PointerType<'ctx> { + ctx.ctx.i8_type().ptr_type(AddressSpace::default()) +} + +pub fn call_nac3_ndarray_size<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + ndarray: NpArrayValue<'ctx>, +) -> IntValue<'ctx> { + let size_type = ndarray.ty.size_type; + let function = get_size_type_dependent_function(ctx, size_type, "__nac3_ndarray_size", || { + size_type.fn_type(&[get_irrt_ndarray_ptr_type(ctx, size_type).into()], false) + }); + + ctx.builder + .build_call(function, &[ndarray.ptr.into()], "size") + .unwrap() + .try_as_basic_value() + .unwrap_left() + .into_int_value() +} + +pub fn call_nac3_ndarray_fill_generic<'ctx>( + ctx: &CodeGenContext<'ctx, '_>, + ndarray: NpArrayValue<'ctx>, + fill_value: BasicValueEnum<'ctx>, +) { + // Sanity check on type of `fill_value` + check_basic_types_match(ndarray.ty.elem_type, fill_value.get_type().as_basic_type_enum()) + .unwrap(); + + let size_type = ndarray.ty.size_type; + let function = + get_size_type_dependent_function(ctx, size_type, "__nac3_ndarray_fill_generic", || { + ctx.ctx.void_type().fn_type( + &[ + get_irrt_ndarray_ptr_type(ctx, size_type).into(), // NDArray* ndarray + get_irrt_opaque_uint8_ptr_type(ctx).into(), // uint8_t* pvalue + ], + false, + ) + }); + + // Put `fill_value` onto the stack and get a pointer to it, and that pointer will be `pvalue` + let pvalue = ctx.builder.build_alloca(ndarray.ty.elem_type, "fill_value").unwrap(); + ctx.builder.build_store(pvalue, fill_value).unwrap(); + + // Cast pvalue to `uint8_t*` + let pvalue = ctx.builder.build_pointer_cast(pvalue, get_irrt_opaque_uint8_ptr_type(ctx), "").unwrap(); + + // Call the IRRT function + ctx.builder + .build_call( + function, + &[ + ndarray.ptr.into(), // ndarray + pvalue.into(), // pvalue + ], + "", + ) + .unwrap(); +} diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 17952369..9fd955c3 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -7,6 +7,7 @@ use crate::{ typedef::{CallId, FuncArg, Type, TypeEnum, Unifier}, }, }; +use classes::NpArrayType; use crossbeam::channel::{unbounded, Receiver, Sender}; use inkwell::{ attributes::{Attribute, AttributeLoc}, @@ -476,7 +477,11 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( ctx, module, generator, unifier, top_level, type_cache, dtype, ); - NDArrayType::new(generator, ctx, element_type).as_base_type().into() + let ndarray_ty = NpArrayType { + size_type: generator.get_size_type(ctx), + elem_type: element_type, + }; + ndarray_ty.get_struct_type(ctx).as_basic_type_enum() } _ => unreachable!( diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 964fd98f..bbc55f79 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -2,15 +2,11 @@ use crate::{ codegen::{ classes::{ ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, NDArrayType, NDArrayValue, - ProxyType, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, + NpArrayType, ProxyType, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }, expr::gen_binop_expr_with_values, - irrt::{ - // calculate_len_for_slice_range, call_ndarray_calc_broadcast, - // call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices, - // call_ndarray_calc_size, - }, + irrt::call_nac3_ndarray_fill_generic, llvm_intrinsics::{self, call_memcpy_generic}, stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, CodeGenContext, CodeGenerator, @@ -26,7 +22,7 @@ use crate::{ typedef::{FunSignature, Type, TypeEnum}, }, }; -use inkwell::types::{AnyTypeEnum, BasicTypeEnum, PointerType}; +use inkwell::types::{AnyTypeEnum, BasicTypeEnum, IntType, PointerType}; use inkwell::{ types::BasicType, values::{BasicValueEnum, IntValue, PointerValue}, @@ -34,6 +30,8 @@ use inkwell::{ }; use nac3parser::ast::{Operator, StrRef}; +use super::{classes::NpArrayValue, stmt::gen_return}; + // /// Creates an uninitialized `NDArray` instance. // fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>( // generator: &mut G, @@ -2015,3 +2013,335 @@ use nac3parser::ast::{Operator, StrRef}; // Ok(()) // } // + +fn simple_assert<'ctx, G>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + cond: IntValue<'ctx>, + msg: &str, +) where + G: CodeGenerator + ?Sized, +{ + let mut full_msg = String::from("simple_assert failed: "); + full_msg.push_str(msg); + ctx.make_assert( + generator, + cond, + "0:ValueError", + full_msg.as_str(), + [None, None, None], + ctx.current_loc, + ); +} + +fn copy_array_slice<'ctx, G, Src, Dst>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dst: Dst, + src: Src, +) where + G: CodeGenerator + ?Sized, + Dst: TypedArrayLikeMutator<'ctx, IntType<'ctx>>, + Src: TypedArrayLikeAccessor<'ctx, IntType<'ctx>>, +{ + // Sanity check + let len_match = ctx + .builder + .build_int_compare( + IntPredicate::EQ, + src.size(ctx, generator), + dst.size(ctx, generator), + "len_match", + ) + .unwrap(); + simple_assert(generator, ctx, len_match, "copy_array_slice length mismatched"); + + let size_type = generator.get_size_type(ctx.ctx); + + let init_val = size_type.const_zero(); + let max_val = (dst.size(ctx, generator), false); + let incr_val = size_type.const_int(1, false); + gen_for_callback_incrementing( + generator, + ctx, + init_val, + max_val, + |generator, ctx, _hooks, idx| { + let value = src.get_typed(ctx, generator, &idx, Some("copy_array_slice.tmp")); + dst.set_typed(ctx, generator, &idx, value); + Ok(()) + }, + incr_val, + ) + .unwrap(); +} + +pub fn alloca_ndarray_uninitialized<'ctx, G>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + elem_type: BasicTypeEnum<'ctx>, + ndims: IntValue<'ctx>, + name: Option<&str>, +) -> Result, String> +where + G: CodeGenerator + ?Sized, +{ + let size_type = generator.get_size_type(ctx.ctx); + let ndarray_ty = NpArrayType { size_type, elem_type }; + let ndarray = ndarray_ty.var_alloc(generator, ctx, ndims, name); + Ok(ndarray) +} + +pub struct Producer<'ctx, G: CodeGenerator + ?Sized, T> { + pub count: IntValue<'ctx>, + pub write_to_slice: Box< + dyn Fn( + &mut G, + &mut CodeGenContext<'ctx, '_>, + &TypedArrayLikeAdapter<'ctx, T>, + ) -> Result<(), String> + + 'ctx, + >, +} + +/// TODO: UPDATE DOCUMENTATION +/// LLVM-typed implementation for generating a [`Producer`] that sets a list of ints. +/// +/// * `elem_ty` - The element type of the `NDArray`. +/// * `shape` - The `shape` parameter used to construct the `NDArray`. +/// +/// ### Notes on `shape` +/// +/// Just like numpy, the `shape` argument can be: +/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])` +/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))` +/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` +/// +/// See also [`typecheck::type_inferencer::fold_numpy_function_call_shape_argument`] to +/// learn how `shape` gets from being a Python user expression to here. +fn parse_input_shape_arg<'ctx, G>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + shape: BasicValueEnum<'ctx>, + shape_ty: Type, +) -> Result>, String> +where + G: CodeGenerator + ?Sized, +{ + let size_type = generator.get_size_type(ctx.ctx); + + match &*ctx.unifier.get_ty(shape_ty) { + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + // 1. A list of ints; e.g., `np.empty([600, 800, 3])` + + // A list has to be a PointerValue + let shape_list = ListValue::from_ptr_val(shape.into_pointer_value(), size_type, None); + + // Create `Producer` + let ndims = shape_list.load_size(ctx, Some("count")); + Ok(Producer { + count: ndims, + write_to_slice: Box::new(move |ctx, generator, dst_slice| { + // Basically iterate through the list and write to `dst_slice` accordingly + let init_val = size_type.const_zero(); + let max_val = (ndims, false); + let incr_val = size_type.const_int(1, false); + gen_for_callback_incrementing( + ctx, + generator, + init_val, + max_val, + |generator, ctx, _hooks, idx| { + // Get the dimension at `idx` + let dim = + shape_list.data().get(ctx, generator, &idx, None).into_int_value(); + + // Cast `dim` to SizeT + let dim = ctx + .builder + .build_int_s_extend_or_bit_cast(dim, size_type, "dim_casted") + .unwrap(); + + // Write + dst_slice.set_typed(ctx, generator, &idx, dim); + Ok(()) + }, + incr_val, + )?; + Ok(()) + }), + }) + } + TypeEnum::TTuple { ty: tuple_types } => { + // 2. A tuple of ints; e.g., `np.empty((600, 800, 3))` + + // Get the length/size of the tuple, which also happens to be the value of `ndims`. + let ndims = tuple_types.len(); + + // A tuple has to be a StructValue + // Read [`codegen::expr::gen_expr`] to see how `nac3core` translates a Python tuple into LLVM. + let shape_tuple = shape.into_struct_value(); + + Ok(Producer { + count: size_type.const_int(ndims as u64, false), + write_to_slice: Box::new(move |generator, ctx, dst_slice| { + for dim_i in 0..ndims { + // Get the dimension at `dim_i` + let dim = ctx + .builder + .build_extract_value( + shape_tuple, + dim_i as u32, + format!("dim{dim_i}").as_str(), + ) + .unwrap() + .into_int_value(); + + // Cast `dim` to SizeT + let dim = ctx + .builder + .build_int_s_extend_or_bit_cast(dim, size_type, "dim_casted") + .unwrap(); + + // Write + dst_slice.set_typed( + ctx, + generator, + &size_type.const_int(dim_i as u64, false), + dim, + ); + } + Ok(()) + }), + }) + } + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() => + { + // 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])` + + // The value has to be an integer + let shape_int = shape.into_int_value(); + + Ok(Producer { + count: size_type.const_int(1, false), + write_to_slice: Box::new(move |generator, ctx, dst_slice| { + // Only index 0 is set with the input value + let dim_i = size_type.const_zero(); + + // Cast `shape_int` to SizeT + let dim = ctx + .builder + .build_int_s_extend_or_bit_cast(shape_int, size_type, "dim_casted") + .unwrap(); + + // Write + dst_slice.set_typed(ctx, generator, &dim_i, dim); + Ok(()) + }), + }) + } + _ => panic!("parse_input_shape_arg encountered unknown type"), + } +} + +/// TODO: DOCUMENT ME +fn alloca_ndarray_uninitialized_shaped<'ctx, G>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + elem_type: BasicTypeEnum<'ctx>, + shape_producer: Producer<'ctx, G, IntValue<'ctx>>, + name: Option<&str>, +) -> Result, String> +where + G: CodeGenerator + ?Sized, +{ + // Allocate an uninitialized ndarray + let ndims = shape_producer.count; + let ndarray = alloca_ndarray_uninitialized(generator, ctx, elem_type, ndims, name)?; + + // Fill `ndarray.shape` with `shape_producer` + (shape_producer.write_to_slice)(generator, ctx, &ndarray.shape_slice(ctx))?; + Ok(ndarray) +} + +/// LLVM-typed implementation for generating the implementation for constructing an empty `NDArray`. +fn call_ndarray_empty_impl<'ctx, G>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + elem_ty: Type, + shape: BasicValueEnum<'ctx>, + shape_ty: Type, + name: Option<&str>, +) -> Result, String> +where + G: CodeGenerator + ?Sized, +{ + let elem_type = ctx.get_llvm_type(generator, elem_ty); + let shape_producer = parse_input_shape_arg(generator, ctx, shape, shape_ty)?; + alloca_ndarray_uninitialized_shaped(generator, ctx, elem_type, shape_producer, name) +} + +/// Generates LLVM IR for `np.empty`. +pub fn gen_ndarray_empty<'ctx>( + context: &mut CodeGenContext<'ctx, '_>, + obj: &Option<(Type, ValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + args: &[(Option, ValueEnum<'ctx>)], + generator: &mut dyn CodeGenerator, +) -> Result, String> { + assert!(obj.is_none()); + assert_eq!(args.len(), 1); + + let shape_ty = fun.0.args[0].ty; + let shape = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; + + let ndarray = call_ndarray_empty_impl( + generator, + context, + context.primitives.float, + shape, + shape_ty, + None, + )?; + Ok(ndarray.ptr) +} + +/// Generates LLVM IR for `np.zeros`. +/// +/// NOTE: Current `dtype` is always `float64`. +pub fn gen_ndarray_zeros<'ctx>( + context: &mut CodeGenContext<'ctx, '_>, + obj: &Option<(Type, ValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + args: &[(Option, ValueEnum<'ctx>)], + generator: &mut dyn CodeGenerator, +) -> Result, String> { + assert!(obj.is_none()); + assert_eq!(args.len(), 1); + + let shape_ty = fun.0.args[0].ty; + let shape = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?; + + // Allocate an ndarray and fill it later + let ndarray = call_ndarray_empty_impl( + generator, + context, + context.primitives.float, // float64 + shape, + shape_ty, + None, + )?; + + // TRICK: The float64 type could be conveniently extracted out of `ndarray` + let float_type = ndarray.ty.elem_type.into_float_type(); + + // Fill the ndarray + call_nac3_ndarray_fill_generic(context, ndarray, float_type.const_float(1.0).into()); + + // Return our ndarray + println!("ndarray.ptr = {}", ndarray.ptr); + Ok(ndarray.ptr) +} diff --git a/nac3core/src/lib.rs b/nac3core/src/lib.rs index 4ffd60b1..474962a7 100644 --- a/nac3core/src/lib.rs +++ b/nac3core/src/lib.rs @@ -23,3 +23,4 @@ pub mod codegen; pub mod symbol_resolver; pub mod toplevel; pub mod typecheck; +pub mod util; diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index e49748d9..84fa91f6 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1,5 +1,6 @@ use std::iter::once; +use crate::util::SizeVariant; use helper::{debug_assert_prim_is_allowed, make_exception_fields, PrimDefDetails}; use indexmap::IndexMap; use inkwell::{ @@ -278,19 +279,10 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built .collect() } -/// A helper enum used by [`BuiltinBuilder`] -#[derive(Clone, Copy)] -enum SizeVariant { - Bits32, - Bits64, -} - -impl SizeVariant { - fn of_int(self, primitives: &PrimitiveStore) -> Type { - match self { - SizeVariant::Bits32 => primitives.int32, - SizeVariant::Bits64 => primitives.int64, - } +fn size_variant_to_int_type(variant: SizeVariant, primitives: &PrimitiveStore) -> Type { + match variant { + SizeVariant::Bits32 => primitives.int32, + SizeVariant::Bits64 => primitives.int64, } } @@ -1061,7 +1053,7 @@ impl<'a> BuiltinBuilder<'a> { ); // The size variant of the function determines the size of the returned int. - let int_sized = size_variant.of_int(self.primitives); + let int_sized = size_variant_to_int_type(size_variant, self.primitives); let ndarray_int_sized = make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.ty)); @@ -1086,7 +1078,7 @@ impl<'a> BuiltinBuilder<'a> { let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - let ret_elem_ty = size_variant.of_int(&ctx.primitives); + let ret_elem_ty = size_variant_to_int_type(size_variant, &ctx.primitives); Ok(Some(builtin_fns::call_round(generator, ctx, (arg_ty, arg), ret_elem_ty)?)) }), ) @@ -1127,7 +1119,7 @@ impl<'a> BuiltinBuilder<'a> { make_ndarray_ty(self.unifier, self.primitives, Some(float), Some(common_ndim.ty)); // The size variant of the function determines the type of int returned - let int_sized = size_variant.of_int(self.primitives); + let int_sized = size_variant_to_int_type(size_variant, self.primitives); let ndarray_int_sized = make_ndarray_ty(self.unifier, self.primitives, Some(int_sized), Some(common_ndim.ty)); @@ -1150,7 +1142,7 @@ impl<'a> BuiltinBuilder<'a> { let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - let ret_elem_ty = size_variant.of_int(&ctx.primitives); + let ret_elem_ty = size_variant_to_int_type(size_variant, &ctx.primitives); let func = match kind { Kind::Ceil => builtin_fns::call_ceil, Kind::Floor => builtin_fns::call_floor, @@ -1201,14 +1193,13 @@ impl<'a> BuiltinBuilder<'a> { self.ndarray_float, &[(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")], Box::new(move |ctx, obj, fun, args, generator| { - todo!() - // let func = match prim { - // PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => gen_ndarray_empty, - // PrimDef::FunNpZeros => gen_ndarray_zeros, - // PrimDef::FunNpOnes => gen_ndarray_ones, - // _ => unreachable!(), - // }; - // func(ctx, &obj, fun, &args, generator).map(|val| Some(val.as_basic_value_enum())) + let func = match prim { + PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => gen_ndarray_empty, + PrimDef::FunNpZeros => gen_ndarray_zeros, + PrimDef::FunNpOnes => todo!(), // gen_ndarray_ones, + _ => unreachable!(), + }; + func(ctx, &obj, fun, &args, generator).map(|val| Some(val.as_basic_value_enum())) }), ) } diff --git a/nac3core/src/util.rs b/nac3core/src/util.rs new file mode 100644 index 00000000..99bc134f --- /dev/null +++ b/nac3core/src/util.rs @@ -0,0 +1,5 @@ +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SizeVariant { + Bits32, + Bits64, +} diff --git a/nac3standalone/demo/src/my_ndarray.py b/nac3standalone/demo/src/my_ndarray.py new file mode 100644 index 00000000..94693d26 --- /dev/null +++ b/nac3standalone/demo/src/my_ndarray.py @@ -0,0 +1,3 @@ +def run() -> int32: + hello = np_zeros((3, 4)) + return 0 \ No newline at end of file diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index c2a1d194..64752e91 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -449,6 +449,9 @@ fn main() { .create_target_machine(llvm_options.opt_level) .expect("couldn't create target machine"); + // NOTE: DEBUG PRINT + main.print_to_file("standalone.ll").unwrap(); + let pass_options = PassBuilderOptions::create(); pass_options.set_merge_functions(true); let passes = format!("default", opt_level as u32);