From 3f4ee433f1ab34f263fe91404c74a169025d0af7 Mon Sep 17 00:00:00 2001 From: lyken Date: Fri, 12 Jul 2024 18:00:58 +0800 Subject: [PATCH] WIP: core: save progress --- nac3core/irrt/irrt_error_context.hpp | 31 ++ nac3core/irrt/irrt_everything.hpp | 8 +- nac3core/irrt/irrt_printer.hpp | 82 +++++ nac3core/irrt/irrt_test.cpp | 12 + nac3core/irrt/irrt_utils.hpp | 37 ++ nac3core/src/codegen/classes.rs | 506 +++++++++------------------ nac3core/src/codegen/expr.rs | 1 + nac3core/src/codegen/irrt/classes.rs | 87 +++++ nac3core/src/codegen/irrt/mod.rs | 141 ++++++-- nac3core/src/codegen/mod.rs | 1 + nac3core/src/codegen/numpy.rs | 4 +- nac3core/src/codegen/structure.rs | 318 +++++++++++++++++ nac3core/src/toplevel/builtins.rs | 10 +- 13 files changed, 847 insertions(+), 391 deletions(-) create mode 100644 nac3core/irrt/irrt_error_context.hpp create mode 100644 nac3core/irrt/irrt_printer.hpp create mode 100644 nac3core/src/codegen/irrt/classes.rs create mode 100644 nac3core/src/codegen/structure.rs diff --git a/nac3core/irrt/irrt_error_context.hpp b/nac3core/irrt/irrt_error_context.hpp new file mode 100644 index 00000000..010f5d67 --- /dev/null +++ b/nac3core/irrt/irrt_error_context.hpp @@ -0,0 +1,31 @@ +#pragma once + +#include "irrt_printer.hpp" + +namespace { + #define MAX_ERROR_NAME_LEN 32 + + // TODO: right now just to report some messages for now + struct ErrorContext { + Printer error; + // TODO: add error_class_name?? + + void initialize(char* string_base_ptr, uint32_t max_length) { + error.initialize(string_base_ptr, max_length); + } + + bool has_error() { + return error.length > 0; + } + }; +} + +extern "C" { + void __nac3_error_context_init(ErrorContext* ctx, char* string_base_ptr, uint32_t max_length) { + ctx->initialize(string_base_ptr, max_length); + } + + uint8_t __nac3_error_context_has_error(ErrorContext* ctx) { + return (uint8_t) ctx->has_error(); + } +} \ No newline at end of file diff --git a/nac3core/irrt/irrt_everything.hpp b/nac3core/irrt/irrt_everything.hpp index 81e0bdc8..7dbdd608 100644 --- a/nac3core/irrt/irrt_everything.hpp +++ b/nac3core/irrt/irrt_everything.hpp @@ -1,10 +1,12 @@ #pragma once -#include "irrt_utils.hpp" -#include "irrt_typedefs.hpp" #include "irrt_basic.hpp" -#include "irrt_slice.hpp" +#include "irrt_error_context.hpp" #include "irrt_numpy_ndarray.hpp" +#include "irrt_printer.hpp" +#include "irrt_slice.hpp" +#include "irrt_typedefs.hpp" +#include "irrt_utils.hpp" /* All IRRT implementations. diff --git a/nac3core/irrt/irrt_printer.hpp b/nac3core/irrt/irrt_printer.hpp new file mode 100644 index 00000000..88d3598d --- /dev/null +++ b/nac3core/irrt/irrt_printer.hpp @@ -0,0 +1,82 @@ +#pragma once + +#include "irrt_typedefs.hpp" + +// TODO: obviously implementing printf from scratch is bad, +// is there a header only, no-cstdlib library for this? + +namespace { + struct Printer { + char* string_base_ptr; + uint32_t max_length; + uint32_t length; // NOTE: this could be incremented past max_length, which indicates + + void initialize(char *string_base_ptr, uint32_t max_length) { + this->string_base_ptr = string_base_ptr; + this->max_length = max_length; + this->length = 0; + } + + void put_space() { + put_char(' '); + } + + void put_char(char ch) { + push_char(ch); + } + + void put_string(const char* string) { + // TODO: optimize? + while (*string != '\0') { + push_char(*string); + string++; // Move to next char + } + } + + template + void put_int(T value) { + // NOTE: Try not to use recursion to print the digits + + // value == 0 is a special case + if (value == 0) { + push_char('0'); + } else { + // Add a '-' if the value is negative + if (value < 0) { + push_char('-'); + value = -value; // Negate then continue to print the digits + } + + // TODO: Recursion is a bad idea on embedded systems? + uint32_t num_digits = int_log_floor(value, 10) + 1; + put_int_helper(num_digits, value); + } + } + + // TODO: implement put_float() and more would be useful + private: + void push_char(char ch) { + if (length < max_length) { + string_base_ptr[length] = ch; + } + + // NOTE: this could increment past max_length, + // to indicate the true length of the message even if it gets cut off + length++; + } + + template + void put_int_helper(uint32_t num_digits, T value) { + // Print the digits recursively + __builtin_assume(0 <= value); + + if (num_digits > 0) { + put_int_helper(num_digits - 1, value / 10); + + uint32_t digit = value % 10; + char digit_char = '0' + (char) digit; + put_char(digit_char); + } + } + }; +} \ No newline at end of file diff --git a/nac3core/irrt/irrt_test.cpp b/nac3core/irrt/irrt_test.cpp index b05e0ac6..663af2c2 100644 --- a/nac3core/irrt/irrt_test.cpp +++ b/nac3core/irrt/irrt_test.cpp @@ -675,6 +675,17 @@ void test_ndarray_broadcast_1() { assert_values_match("dst_ndarray[1, 2, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {1, 2, 3}))); } +void test_printer() { + const uint32_t buffer_len = 256; + char buffer[buffer_len]; + Printer printer = { + .string_base_ptr = buffer, + .max_length = buffer_len, + .length = 0 + }; + +} + int main() { test_calc_size_from_shape_normal(); test_calc_size_from_shape_has_zero(); @@ -691,5 +702,6 @@ int main() { test_ndslice_3(); test_can_broadcast_shape(); test_ndarray_broadcast_1(); + test_printer(); return 0; } \ No newline at end of file diff --git a/nac3core/irrt/irrt_utils.hpp b/nac3core/irrt/irrt_utils.hpp index 8d69b6a1..985d74a7 100644 --- a/nac3core/irrt/irrt_utils.hpp +++ b/nac3core/irrt/irrt_utils.hpp @@ -21,6 +21,43 @@ namespace { return true; } + template + uint32_t int_log_floor(T value, T base) { + uint32_t result = 0; + while (value < base) { + result++; + value /= base; + } + return result; + } + + bool string_is_empty(const char *str) { + return str[0] == '\0'; + } + + // TODO: DOCUMENT ME!!!!! + // returns false if `src_str` could not be fully copied over to `dst_str` + bool string_copy(uint32_t dst_max_size, char* dst_str, const char* src_str) { + // This function guarantess that `dst_str` will be null-terminated, + + for (uint32_t i = 0; i < dst_max_size; i++) { + bool is_last = i + 1 == dst_max_size; + if (is_last && src_str[i] != '\0') { + dst_str[i] = '\0'; + return false; + } + + if (src_str[i] == '\0') { + dst_str[i] = '\0'; + return true; + } + + dst_str[i] = src_str[i]; + } + + __builtin_unreachable(); + } + void irrt_panic() { // Crash the program for now. // TODO: Don't crash the program diff --git a/nac3core/src/codegen/classes.rs b/nac3core/src/codegen/classes.rs index 04f26a5a..65aa857c 100644 --- a/nac3core/src/codegen/classes.rs +++ b/nac3core/src/codegen/classes.rs @@ -1768,357 +1768,163 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeMutator<'ctx, { } -#[derive(Debug, Clone, Copy)] -pub struct StructField<'ctx> { - /// The GEP index of this struct field. - pub gep_index: u32, - /// Name of this struct field. - /// - /// Used for generating names. - pub name: &'static str, - /// The type of this struct field. - pub ty: BasicTypeEnum<'ctx>, -} +// #[derive(Debug, Clone, Copy)] +// pub struct StructField<'ctx> { +// /// The GEP index of this struct field. +// pub gep_index: u32, +// /// Name of this struct field. +// /// +// /// Used for generating names. +// pub name: &'static str, +// /// The type of this struct field. +// pub ty: BasicTypeEnum<'ctx>, +// } +// +// pub struct StructFields<'ctx> { +// /// Name of the struct. +// /// +// /// Used for generating names. +// pub name: &'static str, +// +// /// All the [`StructField`]s of this struct. +// /// +// /// **NOTE:** The index position of a [`StructField`] +// /// matches the element's [`StructField::index`]. +// pub fields: Vec>, +// } +// +// pub struct StructFieldsBuilder<'ctx> { +// gep_index_counter: u32, +// /// Name of the struct to be built. +// name: &'static str, +// fields: Vec>, +// } +// +// impl<'ctx> StructField<'ctx> { +// /// TODO: DOCUMENT ME +// pub fn gep( +// &self, +// ctx: &CodeGenContext<'ctx, '_>, +// struct_ptr: PointerValue<'ctx>, +// ) -> PointerValue<'ctx> { +// let index_type = ctx.ctx.i32_type(); // TODO: I think I'm not supposed to use i32 for GEP like that +// unsafe { +// ctx.builder +// .build_in_bounds_gep( +// struct_ptr, +// &[index_type.const_zero(), index_type.const_int(self.gep_index as u64, false)], +// self.name, +// ) +// .unwrap() +// } +// } +// +// /// TODO: DOCUMENT ME +// pub fn load( +// &self, +// ctx: &CodeGenContext<'ctx, '_>, +// struct_ptr: PointerValue<'ctx>, +// ) -> BasicValueEnum<'ctx> { +// ctx.builder.build_load(self.gep(ctx, struct_ptr), self.name).unwrap() +// } +// +// /// TODO: DOCUMENT ME +// pub fn store(&self, ctx: &CodeGenContext<'ctx, '_>, struct_ptr: PointerValue<'ctx>, value: V) +// where +// V: BasicValue<'ctx>, +// { +// ctx.builder.build_store(self.gep(ctx, struct_ptr), value).unwrap(); +// } +// } -pub struct StructFields<'ctx> { - /// Name of the struct. - /// - /// Used for generating names. - pub name: &'static str, +// type IsInstanceError = String; +// type IsInstanceResult = Result<(), IsInstanceError>; - /// All the [`StructField`]s of this struct. - /// - /// **NOTE:** The index position of a [`StructField`] - /// matches the element's [`StructField::index`]. - pub fields: Vec>, -} +// pub fn check_basic_types_match<'ctx, A, B>(expected: A, got: B) -> IsInstanceResult +// where +// A: BasicType<'ctx>, +// B: BasicType<'ctx>, +// { +// let expected = expected.as_basic_type_enum(); +// let got = got.as_basic_type_enum(); -pub struct StructFieldsBuilder<'ctx> { - gep_index_counter: u32, - /// Name of the struct to be built. - name: &'static str, - fields: Vec>, -} +// // Put those logic into here, +// // otherwise there is always a fallback reporting on any kind of mismatch +// match (expected, got) { +// (BasicTypeEnum::IntType(expected), BasicTypeEnum::IntType(got)) => { +// if expected.get_bit_width() != got.get_bit_width() { +// return Err(format!( +// "Expected IntType ({expected}-bit(s)), got IntType ({got}-bit(s))" +// )); +// } +// } +// (expected, got) => { +// if expected != got { +// return Err(format!("Expected {expected}, got {got}")); +// } +// } +// } +// Ok(()) +// } -impl<'ctx> StructField<'ctx> { - /// TODO: DOCUMENT ME - pub fn gep( - &self, - ctx: &CodeGenContext<'ctx, '_>, - struct_ptr: PointerValue<'ctx>, - ) -> PointerValue<'ctx> { - let index_type = ctx.ctx.i32_type(); // TODO: I think I'm not supposed to use i32 for GEP like that - unsafe { - ctx.builder - .build_in_bounds_gep( - struct_ptr, - &[index_type.const_zero(), index_type.const_int(self.gep_index as u64, false)], - self.name, - ) - .unwrap() - } - } - - /// TODO: DOCUMENT ME - pub fn load( - &self, - ctx: &CodeGenContext<'ctx, '_>, - struct_ptr: PointerValue<'ctx>, - ) -> BasicValueEnum<'ctx> { - ctx.builder.build_load(self.gep(ctx, struct_ptr), self.name).unwrap() - } - - /// TODO: DOCUMENT ME - pub fn store(&self, ctx: &CodeGenContext<'ctx, '_>, struct_ptr: PointerValue<'ctx>, value: V) - where - V: BasicValue<'ctx>, - { - ctx.builder.build_store(self.gep(ctx, struct_ptr), value).unwrap(); - } -} - -type IsInstanceError = String; -type IsInstanceResult = Result<(), IsInstanceError>; - -pub fn check_basic_types_match<'ctx, A, B>(expected: A, got: B) -> IsInstanceResult -where - A: BasicType<'ctx>, - B: BasicType<'ctx>, -{ - let expected = expected.as_basic_type_enum(); - let got = got.as_basic_type_enum(); - - // Put those logic into here, - // otherwise there is always a fallback reporting on any kind of mismatch - match (expected, got) { - (BasicTypeEnum::IntType(expected), BasicTypeEnum::IntType(got)) => { - if expected.get_bit_width() != got.get_bit_width() { - return Err(format!( - "Expected IntType ({expected}-bit(s)), got IntType ({got}-bit(s))" - )); - } - } - (expected, got) => { - if expected != got { - return Err(format!("Expected {expected}, got {got}")); - } - } - } - Ok(()) -} - -impl<'ctx> StructFields<'ctx> { - pub fn num_fields(&self) -> u32 { - self.fields.len() as u32 - } - - pub fn get_struct_type(&self, ctx: &'ctx Context) -> StructType<'ctx> { - let llvm_fields = self.fields.iter().map(|field| field.ty).collect_vec(); - ctx.struct_type(llvm_fields.as_slice(), false) - } - - pub fn is_type(&self, scrutinee: StructType<'ctx>) -> IsInstanceResult { - // Check scrutinee's number of struct fields - if scrutinee.count_fields() != self.num_fields() { - return Err(format!( - "Expected {expected_count} field(s) in `{struct_name}` type, got {got_count}", - struct_name = self.name, - expected_count = self.num_fields(), - got_count = scrutinee.count_fields(), - )); - } - - // Check the scrutinee's field types - for field in self.fields.iter() { - let expected_field_ty = field.ty; - let got_field_ty = scrutinee.get_field_type_at_index(field.gep_index).unwrap(); - - if let Err(field_err) = check_basic_types_match(expected_field_ty, got_field_ty) { - return Err(format!( - "Field GEP index {gep_index} does not match the expected type of ({struct_name}::{field_name}): {field_err}", - gep_index = field.gep_index, - struct_name = self.name, - field_name = field.name, - )); - } - } - - // Done - Ok(()) - } -} - -impl<'ctx> StructFieldsBuilder<'ctx> { - pub fn start(name: &'static str) -> Self { - StructFieldsBuilder { gep_index_counter: 0, name, fields: Vec::new() } - } - - pub fn add_field(&mut self, name: &'static str, ty: BasicTypeEnum<'ctx>) -> StructField<'ctx> { - let index = self.gep_index_counter; - self.gep_index_counter += 1; - - let field = StructField { gep_index: index, name, ty }; - self.fields.push(field); // Register into self.fields - - field // Return to the caller to conveniently let them do whatever they want - } - - pub fn end(self) -> StructFields<'ctx> { - StructFields { name: self.name, fields: self.fields } - } -} - -// TODO: Use derppening's abstraction -#[derive(Debug, Clone, Copy)] -pub struct NpArrayType<'ctx> { - pub size_type: IntType<'ctx>, - pub elem_type: BasicTypeEnum<'ctx>, -} - -pub struct NpArrayStructFields<'ctx> { - pub whole_struct: StructFields<'ctx>, - pub data: StructField<'ctx>, - pub itemsize: StructField<'ctx>, - pub ndims: StructField<'ctx>, - pub shape: StructField<'ctx>, - pub strides: StructField<'ctx>, -} - -impl<'ctx> NpArrayType<'ctx> { - pub fn new_opaque_elem( - ctx: &CodeGenContext<'ctx, '_>, - size_type: IntType<'ctx>, - ) -> NpArrayType<'ctx> { - NpArrayType { size_type, elem_type: ctx.ctx.i8_type().as_basic_type_enum() } - } - - pub fn get_struct_type(&self, ctx: &'ctx Context) -> StructType<'ctx> { - self.fields(ctx).whole_struct.get_struct_type(ctx) - } - - pub fn fields(&self, ctx: &'ctx Context) -> NpArrayStructFields<'ctx> { - let mut builder = StructFieldsBuilder::start("NpArray"); - - let addrspace = AddressSpace::default(); - - let byte_type = ctx.i8_type(); - - // Make sure the struct matches PERFECTLY with that defined in `nac3core/irrt`. - let data = builder.add_field("data", byte_type.ptr_type(addrspace).into()); - let itemsize = builder.add_field("itemsize", self.size_type.into()); - let ndims = builder.add_field("ndims", self.size_type.into()); - let shape = builder.add_field("shape", self.size_type.ptr_type(addrspace).into()); - let strides = builder.add_field("strides", self.size_type.ptr_type(addrspace).into()); - - NpArrayStructFields { whole_struct: builder.end(), data, itemsize, ndims, shape, strides } - } - - /// Allocate an `ndarray` on stack, with the following notes: - /// - /// - `ndarray.ndims` will be initialized to `in_ndims`. - /// - `ndarray.itemsize` will be initialized to the size of `self.elem_type.size_of()`. - /// - `ndarray.shape` and `ndarray.strides` will be allocated on the stack with number of elements being `in_ndims`, - /// all with empty/uninitialized values. - pub fn alloca( - &self, - ctx: &mut CodeGenContext<'ctx, '_>, - in_ndims: IntValue<'ctx>, - name: &str, - ) -> NpArrayValue<'ctx> { - let ptr = ctx - .builder - .build_alloca(self.get_struct_type(ctx.ctx).as_basic_type_enum(), name) - .unwrap(); - - // Allocate `in_dims` number of `size_type` on the stack for `shape` and `strides` - let allocated_shape = ctx - .builder - .build_array_alloca(self.size_type.as_basic_type_enum(), in_ndims, "allocated_shape") - .unwrap(); - let allocated_strides = ctx - .builder - .build_array_alloca(self.size_type.as_basic_type_enum(), in_ndims, "allocated_strides") - .unwrap(); - - let value = NpArrayValue { ty: *self, ptr }; - value.store_ndims(ctx, in_ndims); - value.store_itemsize(ctx, self.elem_type.size_of().unwrap()); - value.store_shape(ctx, allocated_shape); - value.store_strides(ctx, allocated_strides); - - return value; - } - - pub fn value_from_ptr( - &self, - ctx: &'ctx Context, - in_ndarray_ptr: PointerValue<'ctx>, - ) -> NpArrayValue<'ctx> { - if cfg!(debug_assertions) { - // Sanity check on `in_ndarray_ptr`'s type - - let in_ndarray_struct_type = - in_ndarray_ptr.get_type().get_element_type().into_struct_type(); - - // unwrap to check - self.fields(ctx).whole_struct.is_type(in_ndarray_struct_type).unwrap(); - } - NpArrayValue { ty: *self, ptr: in_ndarray_ptr } - } -} - -#[derive(Debug, Clone, Copy)] -pub struct NpArrayValue<'ctx> { - pub ty: NpArrayType<'ctx>, - pub ptr: PointerValue<'ctx>, -} - -impl<'ctx> NpArrayValue<'ctx> { - pub fn load_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - let field = self.ty.fields(ctx.ctx).data; - field.load(ctx, self.ptr).into_pointer_value() - } - - pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, new_data_ptr: PointerValue<'ctx>) { - let field = self.ty.fields(ctx.ctx).data; - field.store(ctx, self.ptr, new_data_ptr); - } - - pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - let field = self.ty.fields(ctx.ctx).ndims; - field.load(ctx, self.ptr).into_int_value() - } - - pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, new_ndims: IntValue<'ctx>) { - let field = self.ty.fields(ctx.ctx).ndims; - field.store(ctx, self.ptr, new_ndims); - } - - pub fn load_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - let field = self.ty.fields(ctx.ctx).itemsize; - field.load(ctx, self.ptr).into_int_value() - } - - pub fn store_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>, new_itemsize: IntValue<'ctx>) { - let field = self.ty.fields(ctx.ctx).itemsize; - field.store(ctx, self.ptr, new_itemsize); - } - - pub fn load_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - let field = self.ty.fields(ctx.ctx).shape; - field.load(ctx, self.ptr).into_pointer_value() - } - - pub fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, new_shape_ptr: PointerValue<'ctx>) { - let field = self.ty.fields(ctx.ctx).shape; - field.store(ctx, self.ptr, new_shape_ptr); - } - - pub fn load_strides(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { - let field = self.ty.fields(ctx.ctx).strides; - field.load(ctx, self.ptr).into_pointer_value() - } - - pub fn store_strides(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) { - let field = self.ty.fields(ctx.ctx).strides; - field.store(ctx, self.ptr, value); - } - - /// TODO: DOCUMENT ME -- NDIMS WOULD NEVER CHANGE!!!!! - pub fn shape_slice( - &self, - ctx: &CodeGenContext<'ctx, '_>, - ) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { - // Get the pointer to `shape` - let field = self.ty.fields(ctx.ctx).shape; - let shape = field.load(ctx, self.ptr).into_pointer_value(); - - // Load `ndims` - let ndims = self.load_ndims(ctx); - - TypedArrayLikeAdapter { - adapted: ArraySliceValue(shape, ndims, Some(field.name)), - downcast_fn: Box::new(|_ctx, x| x.into_int_value()), - upcast_fn: Box::new(|_ctx, x| x.as_basic_value_enum()), - } - } - - /// TODO: DOCUMENT ME -- NDIMS WOULD NEVER CHANGE!!!!! - pub fn strides_slice( - &self, - ctx: &CodeGenContext<'ctx, '_>, - ) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> { - // Get the pointer to `strides` - let field = self.ty.fields(ctx.ctx).strides; - let strides = field.load(ctx, self.ptr).into_pointer_value(); - - // Load `ndims` - let ndims = self.load_ndims(ctx); - - TypedArrayLikeAdapter { - adapted: ArraySliceValue(strides, ndims, Some(field.name)), - downcast_fn: Box::new(|_ctx, x| x.into_int_value()), - upcast_fn: Box::new(|_ctx, x| x.as_basic_value_enum()), - } - } -} +// impl<'ctx> StructFields<'ctx> { +// pub fn num_fields(&self) -> u32 { +// self.fields.len() as u32 +// } +// +// pub fn get_struct_type(&self, ctx: &'ctx Context) -> StructType<'ctx> { +// let llvm_fields = self.fields.iter().map(|field| field.ty).collect_vec(); +// ctx.struct_type(llvm_fields.as_slice(), false) +// } +// +// pub fn is_type(&self, scrutinee: StructType<'ctx>) -> IsInstanceResult { +// // Check scrutinee's number of struct fields +// if scrutinee.count_fields() != self.num_fields() { +// return Err(format!( +// "Expected {expected_count} field(s) in `{struct_name}` type, got {got_count}", +// struct_name = self.name, +// expected_count = self.num_fields(), +// got_count = scrutinee.count_fields(), +// )); +// } +// +// // Check the scrutinee's field types +// for field in self.fields.iter() { +// let expected_field_ty = field.ty; +// let got_field_ty = scrutinee.get_field_type_at_index(field.gep_index).unwrap(); +// +// if let Err(field_err) = check_basic_types_match(expected_field_ty, got_field_ty) { +// return Err(format!( +// "Field GEP index {gep_index} does not match the expected type of ({struct_name}::{field_name}): {field_err}", +// gep_index = field.gep_index, +// struct_name = self.name, +// field_name = field.name, +// )); +// } +// } +// +// // Done +// Ok(()) +// } +// } +// +// impl<'ctx> StructFieldsBuilder<'ctx> { +// pub fn start(name: &'static str) -> Self { +// StructFieldsBuilder { gep_index_counter: 0, name, fields: Vec::new() } +// } +// +// pub fn add_field(&mut self, name: &'static str, ty: BasicTypeEnum<'ctx>) -> StructField<'ctx> { +// let index = self.gep_index_counter; +// self.gep_index_counter += 1; +// +// let field = StructField { gep_index: index, name, ty }; +// self.fields.push(field); // Register into self.fields +// +// field // Return to the caller to conveniently let them do whatever they want +// } +// +// pub fn end(self) -> StructFields<'ctx> { +// StructFields { name: self.name, fields: self.fields } +// } +// } +// diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 9a28d42f..9c786b94 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -2202,6 +2202,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( let dst_ndims = deduce_ndims_after_slicing(ndims, ndslices.iter()); // Finally, perform the actual subscript logic + // TODO: `call_ndarray_subscript_impl` under the hood deduces `dst_ndims` again. We could save it some time by passing `dst_ndims` - a TODO? let subndarray = call_ndarray_subscript_impl( generator, ctx, diff --git a/nac3core/src/codegen/irrt/classes.rs b/nac3core/src/codegen/irrt/classes.rs new file mode 100644 index 00000000..c68b5567 --- /dev/null +++ b/nac3core/src/codegen/irrt/classes.rs @@ -0,0 +1,87 @@ +// TODO: Use derppening's abstraction + +use std::marker::PhantomData; + +use inkwell::{ + context::Context, + types::{BasicType, BasicTypeEnum, IntType}, + values::BasicValueEnum, + AddressSpace, +}; + +use crate::codegen::structure::{ + CustomStructType, CustomType, Field, FieldCreator, IntType2, Object, PointerType2, + PointingArrayType, +}; + +#[derive(Debug, Clone, Copy)] +pub struct NpArrayType<'ctx> { + pub size_type: IntType<'ctx>, + pub elem_type: BasicTypeEnum<'ctx>, +} + +pub struct NpArrayFields<'ctx> { + pub data: Field<'ctx, PointerType2<'ctx>>, + pub itemsize: Field<'ctx, IntType2<'ctx>>, + pub ndims: Field<'ctx, IntType2<'ctx>>, + pub shape: Field<'ctx, PointingArrayType<'ctx, IntType2<'ctx>>>, + pub strides: Field<'ctx, PointingArrayType<'ctx, IntType2<'ctx>>>, +} + +pub type NpArrayValue<'ctx> = Object<'ctx, NpArrayType<'ctx>>; + +// impl<'ctx> CustomType<'ctx> for NpArrayType<'ctx> { +// type Value = NpArrayValue<'ctx>; +// +// fn llvm_basic_type_enum( +// &self, +// ctx: &'ctx inkwell::context::Context, +// ) -> inkwell::types::BasicTypeEnum<'ctx> { +// self.llvm_struct_type(ctx).as_basic_type_enum() +// } +// +// fn llvm_field_load( +// &self, +// ctx: &crate::codegen::CodeGenContext<'ctx, '_>, +// field: crate::codegen::structure::FieldInfo, +// struct_ptr: inkwell::values::PointerValue<'ctx>, +// ) -> Self::Value { +// let ok = field.llvm_load(ctx, struct_ptr); +// todo!() +// } +// +// fn llvm_field_store( +// &self, +// ctx: &crate::codegen::CodeGenContext<'ctx, '_>, +// field: crate::codegen::structure::FieldInfo, +// struct_ptr: inkwell::values::PointerValue<'ctx>, +// value: &Self::Value, +// ) { +// todo!() +// } +// } + +impl<'ctx> CustomStructType<'ctx> for NpArrayType<'ctx> { + type Fields = NpArrayFields<'ctx>; + + fn llvm_struct_name() -> &'static str { + "NDArray" + } + + fn add_fields_to(&self, creator: &mut FieldCreator<'ctx>) -> Self::Fields { + let pi8 = creator.ctx.i8_type().ptr_type(AddressSpace::default()); + NpArrayFields { + data: creator.add_field("data", PointerType2(pi8)), + itemsize: creator.add_field("itemsize", IntType2(self.size_type)), + ndims: creator.add_field("ndims", IntType2(self.size_type)), + shape: creator.add_field("shape", PointingArrayType::new(IntType2(self.size_type))), + strides: creator.add_field("strides", PointingArrayType::new(IntType2(self.size_type))), + } + } +} + +impl<'ctx> NpArrayType<'ctx> { + pub fn new_opaque_elem(ctx: &'ctx Context, size_type: IntType<'ctx>) -> Self { + NpArrayType { elem_type: ctx.i8_type().into(), size_type } + } +} diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 7618463a..4336f20c 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -11,13 +11,15 @@ mod test; use super::{ classes::{ check_basic_types_match, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, - NDArrayValue, NpArrayType, NpArrayValue, StructField, StructFields, TypedArrayLikeAdapter, - UntypedArrayLikeAccessor, + NDArrayValue, StructField, StructFields, TypedArrayLikeAdapter, UntypedArrayLikeAccessor, }, - llvm_intrinsics, CodeGenContext, CodeGenerator, + llvm_intrinsics, + structure::CustomStructType, + CodeGenContext, CodeGenerator, }; use crate::codegen::classes::TypedArrayLikeAccessor; use crate::codegen::stmt::gen_for_callback_incrementing; +use classes::{NpArrayType, NpArrayValue}; use crossbeam::channel::IntoIter; use inkwell::{ attributes::{Attribute, AttributeLoc}, @@ -33,6 +35,7 @@ use inkwell::{ }; use itertools::Either; use nac3parser::ast::Expr; +pub mod classes; #[must_use] pub fn load_irrt(ctx: &Context) -> Module { @@ -957,7 +960,7 @@ pub struct UserSlice<'ctx> { pub step: Option>, } -pub struct IrrtUserSliceStructFields<'ctx> { +pub struct IrrtUserSliceTypeStructFields<'ctx> { pub whole_struct: StructFields<'ctx>, pub start_defined: StructField<'ctx>, @@ -971,10 +974,10 @@ pub struct IrrtUserSliceStructFields<'ctx> { } // TODO: EMPTY STRUCT -struct IrrtUserSlice {} +struct IrrtUserSliceType(); -impl IrrtUserSlice { - pub fn fields<'ctx>(ctx: &'ctx Context) -> IrrtUserSliceStructFields<'ctx> { +impl IrrtUserSliceType { + pub fn fields<'ctx>(&self, ctx: &'ctx Context) -> IrrtUserSliceTypeStructFields<'ctx> { let int8 = ctx.i8_type(); // MUST match the corresponding struct defined in IRRT @@ -986,7 +989,7 @@ impl IrrtUserSlice { let step_defined = builder.add_field("step_defined", int8.into()); let step = builder.add_field("step", get_sliceindex_type(ctx).into()); - IrrtUserSliceStructFields { + IrrtUserSliceTypeStructFields { start_defined, start, stop_defined, @@ -998,11 +1001,12 @@ impl IrrtUserSlice { } pub fn alloca_user_slice<'ctx>( + &self, ctx: &CodeGenContext<'ctx, '_>, user_slice: &UserSlice<'ctx>, ) -> PointerValue<'ctx> { // Derive the struct_type - let fields = Self::fields(ctx.ctx); + let fields = self.fields(ctx.ctx); let struct_type = fields.whole_struct.get_struct_type(ctx.ctx); // ...and then allocate for a real `UserSlice` in LLVM @@ -1070,23 +1074,22 @@ where } // TODO: Empty struct -pub struct IrrtNDSlice {} - -pub struct IrrtNDSliceStructFields<'ctx> { +pub struct IrrtNDSliceType(); +pub struct IrrtNDSliceTypeStructFields<'ctx> { pub whole_struct: StructFields<'ctx>, pub type_: StructField<'ctx>, pub slice: StructField<'ctx>, } -impl IrrtNDSlice { - pub fn fields<'ctx>(ctx: &'ctx Context) -> IrrtNDSliceStructFields<'ctx> { +impl IrrtNDSliceType { + pub fn fields<'ctx>(ctx: &'ctx Context) -> IrrtNDSliceTypeStructFields<'ctx> { let mut builder = StructFieldsBuilder::start("NDSlice"); // MUST match the corresponding struct defined in IRRT let type_ = builder.add_field("type", get_ndslicetype_constant_type(ctx).into()); let slice = builder.add_field("slice", get_opaque_uint8_ptr_type(ctx).into()); - IrrtNDSliceStructFields { type_, slice, whole_struct: builder.end() } + IrrtNDSliceTypeStructFields { type_, slice, whole_struct: builder.end() } } pub fn alloca_ndslices<'ctx>( @@ -1124,7 +1127,7 @@ impl IrrtNDSlice { } NDSlice::Slice(user_slice) => { // Allocate the user_slice - let slice_ptr = IrrtUserSlice::alloca_user_slice(ctx, user_slice); + let slice_ptr = IrrtUserSliceType().alloca_user_slice(ctx, user_slice); let type_ = 1; // const NDSliceType INPUT_SLICE_TYPE_SLICE = 1; (type_, slice_ptr) @@ -1156,6 +1159,84 @@ impl IrrtNDSlice { } } +struct IrrtPrinterType(); +struct IrrtPrinterTypeStructFields<'ctx> { + whole_struct: StructFields<'ctx>, + string_base_ptr: StructField<'ctx>, + max_length: StructField<'ctx>, + length: StructField<'ctx>, +} + +impl IrrtPrinterType { + pub fn fields<'ctx>(&self, ctx: &'ctx Context) -> IrrtPrinterTypeStructFields<'ctx> { + let mut builder = StructFieldsBuilder::start("Printer"); + + let string_base_ptr = builder + .add_field("string_base_ptr", ctx.i8_type().ptr_type(AddressSpace::default()).into()); + let max_length = builder.add_field("max_length", ctx.i32_type().into()); + let length = builder.add_field("length", ctx.i32_type().into()); + + IrrtPrinterTypeStructFields { + string_base_ptr, + max_length, + length, + whole_struct: builder.end(), + } + } +} + +struct IrrtPrinterValue<'ctx> { + ty: IrrtPrinterType, + ptr: PointerValue<'ctx>, +} + +impl<'ctx> IrrtPrinterValue<'ctx> { + pub fn hl(&self) {} +} + +struct IrrtErrorContextType(); +struct IrrtErrorContextTypeStructFields<'ctx> { + whole_struct: StructFields<'ctx>, + error: StructField<'ctx>, +} + +impl IrrtErrorContextType { + pub fn fields<'ctx>(&self, ctx: &'ctx Context) -> IrrtErrorContextTypeStructFields<'ctx> { + let mut builder = StructFieldsBuilder::start("ErrorContext"); + + let error = builder.add_field( + "error", + IrrtPrinterType().fields(ctx).whole_struct.get_struct_type(ctx).into(), + ); + + IrrtErrorContextTypeStructFields { error, whole_struct: builder.end() } + } + + pub fn alloca<'ctx>(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> { + let fields = self.fields(ctx.ctx); + let struct_type = fields.whole_struct.get_struct_type(ctx.ctx); + + ctx.builder.build_alloca(struct_type, "error_context").unwrap() + } + + pub fn load_error<'ctx>( + &self, + ctx: &CodeGenContext<'ctx, '_>, + struct_ptr: PointerValue<'ctx>, + ) -> IrrtPrinterValue<'ctx> { + let error = self.fields(ctx.ctx).error; + + IrrtPrinterValue { + ty: IrrtPrinterType(), + ptr: error.load(ctx, struct_ptr).into_pointer_value(), + } + } +} + +// struct IrrtErrorContextValue<'ctx> { +// ty: IrrtErrorContextValue +// } + fn get_size_variant<'ctx>(ty: IntType<'ctx>) -> SizeVariant { match ty.get_bit_width() { 32 => SizeVariant::Bits32, @@ -1198,8 +1279,7 @@ pub fn get_irrt_ndarray_ptr_type<'ctx>( let i8_type = ctx.i8_type(); let ndarray_ty = NpArrayType { size_type, elem_type: i8_type.as_basic_type_enum() }; - let struct_ty = ndarray_ty.get_struct_type(ctx); - struct_ty.ptr_type(AddressSpace::default()) + ndarray_ty.llvm_struct_type(ctx).ptr_type(AddressSpace::default()) } pub fn get_opaque_uint8_ptr_type<'ctx>(ctx: &'ctx Context) -> PointerType<'ctx> { @@ -1317,7 +1397,7 @@ pub fn call_nac3_ndarray_deduce_ndims_after_slicing<'ctx>( &[ size_type.into(), // SizeT ndims size_type.into(), // SizeT num_slices - IrrtNDSlice::fields(ctx.ctx) + IrrtNDSliceType::fields(ctx.ctx) .whole_struct .get_struct_type(ctx.ctx) .ptr_type(AddressSpace::default()) @@ -1360,7 +1440,7 @@ pub fn call_nac3_ndarray_subscript<'ctx>( &[ get_irrt_ndarray_ptr_type(ctx.ctx, size_type).into(), // NDArray* ndarray size_type.into(), // SizeT num_slices - IrrtNDSlice::fields(ctx.ctx) + IrrtNDSliceType::fields(ctx.ctx) .whole_struct .get_struct_type(ctx.ctx) .ptr_type(AddressSpace::default()) @@ -1393,19 +1473,14 @@ pub fn call_nac3_len<'ctx>( let size_type = ndarray.ty.size_type; // Get the IRRT function - let function = get_size_type_dependent_function( - ctx, - size_type, - "__nac3_ndarray_len", - || { - get_sliceindex_type(ctx.ctx).fn_type( - &[ - get_irrt_ndarray_ptr_type(ctx.ctx, size_type).into(), // NDArray *ndarray - ], - false, - ) - }, - ); + let function = get_size_type_dependent_function(ctx, size_type, "__nac3_ndarray_len", || { + get_sliceindex_type(ctx.ctx).fn_type( + &[ + get_irrt_ndarray_ptr_type(ctx.ctx, size_type).into(), // NDArray *ndarray + ], + false, + ) + }); // Call the IRRT function ctx.builder diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 5c2da34a..170c4713 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -44,6 +44,7 @@ pub mod irrt; pub mod llvm_intrinsics; pub mod numpy; pub mod stmt; +pub mod structure; #[cfg(test)] mod test; diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index ba89966c..9c040626 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -38,7 +38,7 @@ use super::{ irrt::{ call_nac3_ndarray_deduce_ndims_after_slicing, call_nac3_ndarray_set_strides_by_shape, call_nac3_ndarray_size, call_nac3_ndarray_subscript, get_irrt_ndarray_ptr_type, - get_opaque_uint8_ptr_type, IrrtNDSlice, NDSlice, + get_opaque_uint8_ptr_type, IrrtNDSliceType, NDSlice, }, stmt::gen_return, }; @@ -2364,7 +2364,7 @@ where let num_slices = size_type.const_int(ndslices.len() as u64, false); // Prepare the argument `slices` - let ndslices_ptr = IrrtNDSlice::alloca_ndslices(ctx, ndslices); + let ndslices_ptr = IrrtNDSliceType::alloca_ndslices(ctx, ndslices); // Get `dst_ndims` let dst_ndims = call_nac3_ndarray_deduce_ndims_after_slicing( diff --git a/nac3core/src/codegen/structure.rs b/nac3core/src/codegen/structure.rs new file mode 100644 index 00000000..4d00c3df --- /dev/null +++ b/nac3core/src/codegen/structure.rs @@ -0,0 +1,318 @@ +use std::marker::PhantomData; + +use inkwell::{ + context::Context, + types::{BasicType, BasicTypeEnum, IntType, PointerType, StructType}, + values::{BasicValue, BasicValueEnum, IntValue, PointerValue}, + AddressSpace, +}; + +use super::CodeGenContext; + +#[derive(Debug, Clone, Copy)] +pub struct FieldInfo { + gep_index: u32, + name: &'static str, +} + +impl FieldInfo { + pub fn llvm_gep<'ctx>( + &self, + ctx: &CodeGenContext<'ctx, '_>, + struct_ptr: PointerValue<'ctx>, + ) -> PointerValue<'ctx> { + let index_type = ctx.ctx.i32_type(); // TODO: I think I'm not supposed to *just* use i32 for GEP like that + unsafe { + ctx.builder + .build_in_bounds_gep( + struct_ptr, + &[index_type.const_zero(), index_type.const_int(self.gep_index as u64, false)], + self.name, + ) + .unwrap() + } + } + + pub fn llvm_load<'ctx>( + &self, + ctx: &CodeGenContext<'ctx, '_>, + struct_ptr: PointerValue<'ctx>, + ) -> BasicValueEnum<'ctx> { + // We will use `self.name` as the LLVM label for debugging purposes + ctx.builder.build_load(self.llvm_gep(ctx, struct_ptr), self.name).unwrap() + } + + pub fn llvm_store<'ctx>( + &self, + ctx: &CodeGenContext<'ctx, '_>, + struct_ptr: PointerValue<'ctx>, + value: BasicValueEnum<'ctx>, + ) { + ctx.builder.build_store(self.llvm_gep(ctx, struct_ptr), value).unwrap(); + } +} + +pub struct Object<'ctx, T> { + pub ty: T, + pub ptr: PointerValue<'ctx>, +} + +pub struct Field<'ctx, T: CustomType<'ctx>> { + pub info: FieldInfo, + pub ty: T, + _phantom: PhantomData<&'ctx ()>, +} + +pub struct FieldCreator<'ctx> { + pub ctx: &'ctx Context, + struct_name: &'ctx str, + gep_index_counter: u32, + fields: Vec<(FieldInfo, BasicTypeEnum<'ctx>)>, +} + +impl<'ctx> FieldCreator<'ctx> { + pub fn new(ctx: &'ctx Context, struct_name: &'ctx str) -> Self { + FieldCreator { ctx, struct_name, gep_index_counter: 0, fields: Vec::new() } + } + + fn next_gep_index(&mut self) -> u32 { + let index = self.gep_index_counter; + self.gep_index_counter += 1; + index + } + + fn get_struct_field_types(&self) -> Vec> { + self.fields.iter().map(|x| x.1.clone()).collect() + } + + pub fn add_field>(&mut self, name: &'static str, ty: T) -> Field<'ctx, T> { + let gep_index = self.next_gep_index(); + + let field_type = ty.llvm_basic_type_enum(self.ctx); + let field_info = FieldInfo { gep_index, name }; + let field = Field { info: field_info, ty, _phantom: PhantomData }; + + self.fields.push((field_info.clone(), field_type)); + + field + } + + fn num_fields(&self) -> u32 { + self.fields.len() as u32 // casted to u32 because that is what inkwell returns + } +} + +pub trait CustomType<'ctx>: Clone { + type Value; + + fn llvm_basic_type_enum(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx>; + + fn llvm_field_load( + &self, + ctx: &CodeGenContext<'ctx, '_>, + field: FieldInfo, + struct_ptr: PointerValue<'ctx>, + ) -> Self::Value; + + fn llvm_field_store( + &self, + ctx: &CodeGenContext<'ctx, '_>, + field: FieldInfo, + struct_ptr: PointerValue<'ctx>, + value: &Self::Value, + ); +} + +#[derive(Debug, Clone, Copy)] +pub struct IntType2<'ctx>(pub IntType<'ctx>); + +impl<'ctx> CustomType<'ctx> for IntType2<'ctx> { + type Value = IntValue<'ctx>; + + fn llvm_basic_type_enum(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> { + self.0.as_basic_type_enum() + } + + fn llvm_field_load( + &self, + ctx: &CodeGenContext<'ctx, '_>, + field: FieldInfo, + struct_ptr: PointerValue<'ctx>, + ) -> Self::Value { + let int_value = field.llvm_load(ctx, struct_ptr).into_int_value(); + assert_eq!(int_value.get_type().get_bit_width(), self.0.get_bit_width()); + int_value + } + + fn llvm_field_store( + &self, + ctx: &CodeGenContext<'ctx, '_>, + field: FieldInfo, + struct_ptr: PointerValue<'ctx>, + int_value: &Self::Value, + ) { + assert_eq!(int_value.get_type().get_bit_width(), self.0.get_bit_width()); + field.llvm_store(ctx, struct_ptr, int_value.as_basic_value_enum()); + } +} + +#[derive(Debug, Clone, Copy)] +pub struct PointerType2<'ctx>(pub PointerType<'ctx>); + +impl<'ctx> CustomType<'ctx> for PointerType2<'ctx> { + type Value = PointerValue<'ctx>; + + fn llvm_basic_type_enum(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> { + self.0.as_basic_type_enum() + } + + fn llvm_field_load( + &self, + ctx: &CodeGenContext<'ctx, '_>, + field: FieldInfo, + struct_ptr: PointerValue<'ctx>, + ) -> Self::Value { + field.llvm_load(ctx, struct_ptr).into_pointer_value() + } + + fn llvm_field_store( + &self, + ctx: &CodeGenContext<'ctx, '_>, + field: FieldInfo, + struct_ptr: PointerValue<'ctx>, + pointer_value: &Self::Value, + ) { + field.llvm_store(ctx, struct_ptr, pointer_value.as_basic_value_enum()); + } +} + +#[derive(Debug, Clone, Copy)] +pub struct PointingArrayType<'ctx, ElementType: CustomType<'ctx>> { + pub element_type: ElementType, + _phantom: PhantomData<&'ctx ()>, +} + +impl<'ctx, ElementType: CustomType<'ctx>> PointingArrayType<'ctx, ElementType> { + pub fn new(element_type: ElementType) -> Self { + PointingArrayType { element_type, _phantom: PhantomData } + } +} + +impl<'ctx, Element: CustomType<'ctx>> CustomType<'ctx> for PointingArrayType<'ctx, Element> { + type Value = Object<'ctx, Self>; + + fn llvm_basic_type_enum(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> { + // Element* + self.element_type + .llvm_basic_type_enum(ctx) + .ptr_type(AddressSpace::default()) + .as_basic_type_enum() + } + + fn llvm_field_load( + &self, + ctx: &CodeGenContext<'ctx, '_>, + field: FieldInfo, + struct_ptr: PointerValue<'ctx>, + ) -> Self::Value { + // Remember that it is just a pointer + Object { ty: self.clone(), ptr: field.llvm_load(ctx, struct_ptr).into_pointer_value() } + } + + fn llvm_field_store( + &self, + ctx: &CodeGenContext<'ctx, '_>, + field: FieldInfo, + struct_ptr: PointerValue<'ctx>, + value: &Self::Value, + ) { + // Remember that it is just a pointer + todo!() + } +} + +pub fn check_basic_types_match<'ctx, A, B>(expected: A, got: B) -> Result<(), String> +where + A: BasicType<'ctx>, + B: BasicType<'ctx>, +{ + let expected = expected.as_basic_type_enum(); + let got = got.as_basic_type_enum(); + + // Put those logic into here, + // otherwise there is always a fallback reporting on any kind of mismatch + match (expected, got) { + (BasicTypeEnum::IntType(expected), BasicTypeEnum::IntType(got)) => { + if expected.get_bit_width() != got.get_bit_width() { + return Err(format!( + "Expected IntType ({expected}-bit(s)), got IntType ({got}-bit(s))" + )); + } + } + (expected, got) => { + if expected != got { + return Err(format!("Expected {expected}, got {got}")); + } + } + } + Ok(()) +} + +pub trait CustomStructType<'ctx> { + type Fields; + + fn llvm_struct_name() -> &'static str; + + fn add_fields_to(&self, creator: &mut FieldCreator<'ctx>) -> Self::Fields; + + fn fields(&self, ctx: &'ctx Context) -> Self::Fields { + let mut creator = FieldCreator::new(ctx, Self::llvm_struct_name()); + let fields = self.add_fields_to(&mut creator); + fields + } + + fn llvm_struct_type(&self, ctx: &'ctx Context) -> StructType<'ctx> { + let mut creator = FieldCreator::new(ctx, Self::llvm_struct_name()); + self.add_fields_to(&mut creator); + + ctx.struct_type(&creator.get_struct_field_types(), false) + } + + fn check_struct_type( + &self, + ctx: &'ctx Context, + scrutinee: StructType<'ctx>, + ) -> Result<(), String> { + let mut creator = FieldCreator::new(ctx, Self::llvm_struct_name()); + self.add_fields_to(&mut creator); + + // Check scrutinee's number of struct fields + let expected_field_count = creator.num_fields(); + let got_field_count = scrutinee.count_fields(); + if got_field_count != expected_field_count { + return Err(format!( + "Expected {expected_count} field(s) in `{struct_name}` type, got {got_count}", + struct_name = Self::llvm_struct_name(), + expected_count = expected_field_count, + got_count = got_field_count, + )); + } + + // Check the scrutinee's field types + for (field_info, expected_field_ty) in creator.fields { + let got_field_ty = scrutinee.get_field_type_at_index(field_info.gep_index).unwrap(); + + if let Err(field_err) = check_basic_types_match(expected_field_ty, got_field_ty) { + return Err(format!( + "Field GEP index {gep_index} does not match the expected type of ({struct_name}::{field_name}): {field_err}", + gep_index = field_info.gep_index, + struct_name = Self::llvm_struct_name(), + field_name = field_info.name, + )); + } + } + + // Done + Ok(()) + } +} diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index abaaae24..c156e91e 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1,6 +1,7 @@ use std::iter::once; -use crate::{codegen::classes::NpArrayType, util::SizeVariant}; +use crate::util::SizeVariant; +use classes::NpArrayType; use helper::{debug_assert_prim_is_allowed, make_exception_fields, PrimDefDetails}; use indexmap::IndexMap; use inkwell::{ @@ -1466,10 +1467,13 @@ impl<'a> BuiltinBuilder<'a> { let ndarray_ptr = arg.into_pointer_value(); // It has to be an ndarray let size_type = generator.get_size_type(ctx.ctx); - let ndarray_ty = NpArrayType::new_opaque_elem(ctx, size_type); // We don't need to care about the element type - we only want the shape + let ndarray_ty = NpArrayType::new_opaque_elem(ctx.ctx, size_type); // We don't need to care about the element type - we only want the shape let ndarray = ndarray_ty.value_from_ptr(ctx.ctx, ndarray_ptr); - Some(call_nac3_len(ctx, ndarray).as_basic_value_enum()) + let result = call_nac3_len(ctx, ndarray).as_basic_value_enum(); + Some(result) + + // Some(.as_basic_value_enum()) // let llvm_i32 = ctx.ctx.i32_type(); // let llvm_usize = generator.get_size_type(ctx.ctx);