forked from M-Labs/nac3
1
0
Fork 0

WIP: core: save progress

This commit is contained in:
lyken 2024-07-12 18:00:58 +08:00
parent ab7ff2ae9d
commit 3f4ee433f1
13 changed files with 847 additions and 391 deletions

View File

@ -0,0 +1,31 @@
#pragma once
#include "irrt_printer.hpp"
namespace {
#define MAX_ERROR_NAME_LEN 32
// TODO: right now just to report some messages for now
struct ErrorContext {
Printer error;
// TODO: add error_class_name??
void initialize(char* string_base_ptr, uint32_t max_length) {
error.initialize(string_base_ptr, max_length);
}
bool has_error() {
return error.length > 0;
}
};
}
extern "C" {
void __nac3_error_context_init(ErrorContext* ctx, char* string_base_ptr, uint32_t max_length) {
ctx->initialize(string_base_ptr, max_length);
}
uint8_t __nac3_error_context_has_error(ErrorContext* ctx) {
return (uint8_t) ctx->has_error();
}
}

View File

@ -1,10 +1,12 @@
#pragma once #pragma once
#include "irrt_utils.hpp"
#include "irrt_typedefs.hpp"
#include "irrt_basic.hpp" #include "irrt_basic.hpp"
#include "irrt_slice.hpp" #include "irrt_error_context.hpp"
#include "irrt_numpy_ndarray.hpp" #include "irrt_numpy_ndarray.hpp"
#include "irrt_printer.hpp"
#include "irrt_slice.hpp"
#include "irrt_typedefs.hpp"
#include "irrt_utils.hpp"
/* /*
All IRRT implementations. All IRRT implementations.

View File

@ -0,0 +1,82 @@
#pragma once
#include "irrt_typedefs.hpp"
// TODO: obviously implementing printf from scratch is bad,
// is there a header only, no-cstdlib library for this?
namespace {
struct Printer {
char* string_base_ptr;
uint32_t max_length;
uint32_t length; // NOTE: this could be incremented past max_length, which indicates
void initialize(char *string_base_ptr, uint32_t max_length) {
this->string_base_ptr = string_base_ptr;
this->max_length = max_length;
this->length = 0;
}
void put_space() {
put_char(' ');
}
void put_char(char ch) {
push_char(ch);
}
void put_string(const char* string) {
// TODO: optimize?
while (*string != '\0') {
push_char(*string);
string++; // Move to next char
}
}
template<typename T>
void put_int(T value) {
// NOTE: Try not to use recursion to print the digits
// value == 0 is a special case
if (value == 0) {
push_char('0');
} else {
// Add a '-' if the value is negative
if (value < 0) {
push_char('-');
value = -value; // Negate then continue to print the digits
}
// TODO: Recursion is a bad idea on embedded systems?
uint32_t num_digits = int_log_floor(value, 10) + 1;
put_int_helper(num_digits, value);
}
}
// TODO: implement put_float() and more would be useful
private:
void push_char(char ch) {
if (length < max_length) {
string_base_ptr[length] = ch;
}
// NOTE: this could increment past max_length,
// to indicate the true length of the message even if it gets cut off
length++;
}
template <typename T>
void put_int_helper(uint32_t num_digits, T value) {
// Print the digits recursively
__builtin_assume(0 <= value);
if (num_digits > 0) {
put_int_helper(num_digits - 1, value / 10);
uint32_t digit = value % 10;
char digit_char = '0' + (char) digit;
put_char(digit_char);
}
}
};
}

View File

@ -675,6 +675,17 @@ void test_ndarray_broadcast_1() {
assert_values_match("dst_ndarray[1, 2, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {1, 2, 3}))); assert_values_match("dst_ndarray[1, 2, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {1, 2, 3})));
} }
void test_printer() {
const uint32_t buffer_len = 256;
char buffer[buffer_len];
Printer printer = {
.string_base_ptr = buffer,
.max_length = buffer_len,
.length = 0
};
}
int main() { int main() {
test_calc_size_from_shape_normal(); test_calc_size_from_shape_normal();
test_calc_size_from_shape_has_zero(); test_calc_size_from_shape_has_zero();
@ -691,5 +702,6 @@ int main() {
test_ndslice_3(); test_ndslice_3();
test_can_broadcast_shape(); test_can_broadcast_shape();
test_ndarray_broadcast_1(); test_ndarray_broadcast_1();
test_printer();
return 0; return 0;
} }

View File

@ -21,6 +21,43 @@ namespace {
return true; return true;
} }
template<typename T>
uint32_t int_log_floor(T value, T base) {
uint32_t result = 0;
while (value < base) {
result++;
value /= base;
}
return result;
}
bool string_is_empty(const char *str) {
return str[0] == '\0';
}
// TODO: DOCUMENT ME!!!!!
// returns false if `src_str` could not be fully copied over to `dst_str`
bool string_copy(uint32_t dst_max_size, char* dst_str, const char* src_str) {
// This function guarantess that `dst_str` will be null-terminated,
for (uint32_t i = 0; i < dst_max_size; i++) {
bool is_last = i + 1 == dst_max_size;
if (is_last && src_str[i] != '\0') {
dst_str[i] = '\0';
return false;
}
if (src_str[i] == '\0') {
dst_str[i] = '\0';
return true;
}
dst_str[i] = src_str[i];
}
__builtin_unreachable();
}
void irrt_panic() { void irrt_panic() {
// Crash the program for now. // Crash the program for now.
// TODO: Don't crash the program // TODO: Don't crash the program

View File

@ -1768,357 +1768,163 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeMutator<'ctx,
{ {
} }
#[derive(Debug, Clone, Copy)] // #[derive(Debug, Clone, Copy)]
pub struct StructField<'ctx> { // pub struct StructField<'ctx> {
/// The GEP index of this struct field. // /// The GEP index of this struct field.
pub gep_index: u32, // pub gep_index: u32,
/// Name of this struct field. // /// Name of this struct field.
/// // ///
/// Used for generating names. // /// Used for generating names.
pub name: &'static str, // pub name: &'static str,
/// The type of this struct field. // /// The type of this struct field.
pub ty: BasicTypeEnum<'ctx>, // pub ty: BasicTypeEnum<'ctx>,
} // }
//
// pub struct StructFields<'ctx> {
// /// Name of the struct.
// ///
// /// Used for generating names.
// pub name: &'static str,
//
// /// All the [`StructField`]s of this struct.
// ///
// /// **NOTE:** The index position of a [`StructField`]
// /// matches the element's [`StructField::index`].
// pub fields: Vec<StructField<'ctx>>,
// }
//
// pub struct StructFieldsBuilder<'ctx> {
// gep_index_counter: u32,
// /// Name of the struct to be built.
// name: &'static str,
// fields: Vec<StructField<'ctx>>,
// }
//
// impl<'ctx> StructField<'ctx> {
// /// TODO: DOCUMENT ME
// pub fn gep(
// &self,
// ctx: &CodeGenContext<'ctx, '_>,
// struct_ptr: PointerValue<'ctx>,
// ) -> PointerValue<'ctx> {
// let index_type = ctx.ctx.i32_type(); // TODO: I think I'm not supposed to use i32 for GEP like that
// unsafe {
// ctx.builder
// .build_in_bounds_gep(
// struct_ptr,
// &[index_type.const_zero(), index_type.const_int(self.gep_index as u64, false)],
// self.name,
// )
// .unwrap()
// }
// }
//
// /// TODO: DOCUMENT ME
// pub fn load(
// &self,
// ctx: &CodeGenContext<'ctx, '_>,
// struct_ptr: PointerValue<'ctx>,
// ) -> BasicValueEnum<'ctx> {
// ctx.builder.build_load(self.gep(ctx, struct_ptr), self.name).unwrap()
// }
//
// /// TODO: DOCUMENT ME
// pub fn store<V>(&self, ctx: &CodeGenContext<'ctx, '_>, struct_ptr: PointerValue<'ctx>, value: V)
// where
// V: BasicValue<'ctx>,
// {
// ctx.builder.build_store(self.gep(ctx, struct_ptr), value).unwrap();
// }
// }
pub struct StructFields<'ctx> { // type IsInstanceError = String;
/// Name of the struct. // type IsInstanceResult = Result<(), IsInstanceError>;
///
/// Used for generating names.
pub name: &'static str,
/// All the [`StructField`]s of this struct. // pub fn check_basic_types_match<'ctx, A, B>(expected: A, got: B) -> IsInstanceResult
/// // where
/// **NOTE:** The index position of a [`StructField`] // A: BasicType<'ctx>,
/// matches the element's [`StructField::index`]. // B: BasicType<'ctx>,
pub fields: Vec<StructField<'ctx>>, // {
} // let expected = expected.as_basic_type_enum();
// let got = got.as_basic_type_enum();
pub struct StructFieldsBuilder<'ctx> { // // Put those logic into here,
gep_index_counter: u32, // // otherwise there is always a fallback reporting on any kind of mismatch
/// Name of the struct to be built. // match (expected, got) {
name: &'static str, // (BasicTypeEnum::IntType(expected), BasicTypeEnum::IntType(got)) => {
fields: Vec<StructField<'ctx>>, // if expected.get_bit_width() != got.get_bit_width() {
} // return Err(format!(
// "Expected IntType ({expected}-bit(s)), got IntType ({got}-bit(s))"
// ));
// }
// }
// (expected, got) => {
// if expected != got {
// return Err(format!("Expected {expected}, got {got}"));
// }
// }
// }
// Ok(())
// }
impl<'ctx> StructField<'ctx> { // impl<'ctx> StructFields<'ctx> {
/// TODO: DOCUMENT ME // pub fn num_fields(&self) -> u32 {
pub fn gep( // self.fields.len() as u32
&self, // }
ctx: &CodeGenContext<'ctx, '_>, //
struct_ptr: PointerValue<'ctx>, // pub fn get_struct_type(&self, ctx: &'ctx Context) -> StructType<'ctx> {
) -> PointerValue<'ctx> { // let llvm_fields = self.fields.iter().map(|field| field.ty).collect_vec();
let index_type = ctx.ctx.i32_type(); // TODO: I think I'm not supposed to use i32 for GEP like that // ctx.struct_type(llvm_fields.as_slice(), false)
unsafe { // }
ctx.builder //
.build_in_bounds_gep( // pub fn is_type(&self, scrutinee: StructType<'ctx>) -> IsInstanceResult {
struct_ptr, // // Check scrutinee's number of struct fields
&[index_type.const_zero(), index_type.const_int(self.gep_index as u64, false)], // if scrutinee.count_fields() != self.num_fields() {
self.name, // return Err(format!(
) // "Expected {expected_count} field(s) in `{struct_name}` type, got {got_count}",
.unwrap() // struct_name = self.name,
} // expected_count = self.num_fields(),
} // got_count = scrutinee.count_fields(),
// ));
/// TODO: DOCUMENT ME // }
pub fn load( //
&self, // // Check the scrutinee's field types
ctx: &CodeGenContext<'ctx, '_>, // for field in self.fields.iter() {
struct_ptr: PointerValue<'ctx>, // let expected_field_ty = field.ty;
) -> BasicValueEnum<'ctx> { // let got_field_ty = scrutinee.get_field_type_at_index(field.gep_index).unwrap();
ctx.builder.build_load(self.gep(ctx, struct_ptr), self.name).unwrap() //
} // if let Err(field_err) = check_basic_types_match(expected_field_ty, got_field_ty) {
// return Err(format!(
/// TODO: DOCUMENT ME // "Field GEP index {gep_index} does not match the expected type of ({struct_name}::{field_name}): {field_err}",
pub fn store<V>(&self, ctx: &CodeGenContext<'ctx, '_>, struct_ptr: PointerValue<'ctx>, value: V) // gep_index = field.gep_index,
where // struct_name = self.name,
V: BasicValue<'ctx>, // field_name = field.name,
{ // ));
ctx.builder.build_store(self.gep(ctx, struct_ptr), value).unwrap(); // }
} // }
} //
// // Done
type IsInstanceError = String; // Ok(())
type IsInstanceResult = Result<(), IsInstanceError>; // }
// }
pub fn check_basic_types_match<'ctx, A, B>(expected: A, got: B) -> IsInstanceResult //
where // impl<'ctx> StructFieldsBuilder<'ctx> {
A: BasicType<'ctx>, // pub fn start(name: &'static str) -> Self {
B: BasicType<'ctx>, // StructFieldsBuilder { gep_index_counter: 0, name, fields: Vec::new() }
{ // }
let expected = expected.as_basic_type_enum(); //
let got = got.as_basic_type_enum(); // pub fn add_field(&mut self, name: &'static str, ty: BasicTypeEnum<'ctx>) -> StructField<'ctx> {
// let index = self.gep_index_counter;
// Put those logic into here, // self.gep_index_counter += 1;
// otherwise there is always a fallback reporting on any kind of mismatch //
match (expected, got) { // let field = StructField { gep_index: index, name, ty };
(BasicTypeEnum::IntType(expected), BasicTypeEnum::IntType(got)) => { // self.fields.push(field); // Register into self.fields
if expected.get_bit_width() != got.get_bit_width() { //
return Err(format!( // field // Return to the caller to conveniently let them do whatever they want
"Expected IntType ({expected}-bit(s)), got IntType ({got}-bit(s))" // }
)); //
} // pub fn end(self) -> StructFields<'ctx> {
} // StructFields { name: self.name, fields: self.fields }
(expected, got) => { // }
if expected != got { // }
return Err(format!("Expected {expected}, got {got}")); //
}
}
}
Ok(())
}
impl<'ctx> StructFields<'ctx> {
pub fn num_fields(&self) -> u32 {
self.fields.len() as u32
}
pub fn get_struct_type(&self, ctx: &'ctx Context) -> StructType<'ctx> {
let llvm_fields = self.fields.iter().map(|field| field.ty).collect_vec();
ctx.struct_type(llvm_fields.as_slice(), false)
}
pub fn is_type(&self, scrutinee: StructType<'ctx>) -> IsInstanceResult {
// Check scrutinee's number of struct fields
if scrutinee.count_fields() != self.num_fields() {
return Err(format!(
"Expected {expected_count} field(s) in `{struct_name}` type, got {got_count}",
struct_name = self.name,
expected_count = self.num_fields(),
got_count = scrutinee.count_fields(),
));
}
// Check the scrutinee's field types
for field in self.fields.iter() {
let expected_field_ty = field.ty;
let got_field_ty = scrutinee.get_field_type_at_index(field.gep_index).unwrap();
if let Err(field_err) = check_basic_types_match(expected_field_ty, got_field_ty) {
return Err(format!(
"Field GEP index {gep_index} does not match the expected type of ({struct_name}::{field_name}): {field_err}",
gep_index = field.gep_index,
struct_name = self.name,
field_name = field.name,
));
}
}
// Done
Ok(())
}
}
impl<'ctx> StructFieldsBuilder<'ctx> {
pub fn start(name: &'static str) -> Self {
StructFieldsBuilder { gep_index_counter: 0, name, fields: Vec::new() }
}
pub fn add_field(&mut self, name: &'static str, ty: BasicTypeEnum<'ctx>) -> StructField<'ctx> {
let index = self.gep_index_counter;
self.gep_index_counter += 1;
let field = StructField { gep_index: index, name, ty };
self.fields.push(field); // Register into self.fields
field // Return to the caller to conveniently let them do whatever they want
}
pub fn end(self) -> StructFields<'ctx> {
StructFields { name: self.name, fields: self.fields }
}
}
// TODO: Use derppening's abstraction
#[derive(Debug, Clone, Copy)]
pub struct NpArrayType<'ctx> {
pub size_type: IntType<'ctx>,
pub elem_type: BasicTypeEnum<'ctx>,
}
pub struct NpArrayStructFields<'ctx> {
pub whole_struct: StructFields<'ctx>,
pub data: StructField<'ctx>,
pub itemsize: StructField<'ctx>,
pub ndims: StructField<'ctx>,
pub shape: StructField<'ctx>,
pub strides: StructField<'ctx>,
}
impl<'ctx> NpArrayType<'ctx> {
pub fn new_opaque_elem(
ctx: &CodeGenContext<'ctx, '_>,
size_type: IntType<'ctx>,
) -> NpArrayType<'ctx> {
NpArrayType { size_type, elem_type: ctx.ctx.i8_type().as_basic_type_enum() }
}
pub fn get_struct_type(&self, ctx: &'ctx Context) -> StructType<'ctx> {
self.fields(ctx).whole_struct.get_struct_type(ctx)
}
pub fn fields(&self, ctx: &'ctx Context) -> NpArrayStructFields<'ctx> {
let mut builder = StructFieldsBuilder::start("NpArray");
let addrspace = AddressSpace::default();
let byte_type = ctx.i8_type();
// Make sure the struct matches PERFECTLY with that defined in `nac3core/irrt`.
let data = builder.add_field("data", byte_type.ptr_type(addrspace).into());
let itemsize = builder.add_field("itemsize", self.size_type.into());
let ndims = builder.add_field("ndims", self.size_type.into());
let shape = builder.add_field("shape", self.size_type.ptr_type(addrspace).into());
let strides = builder.add_field("strides", self.size_type.ptr_type(addrspace).into());
NpArrayStructFields { whole_struct: builder.end(), data, itemsize, ndims, shape, strides }
}
/// Allocate an `ndarray` on stack, with the following notes:
///
/// - `ndarray.ndims` will be initialized to `in_ndims`.
/// - `ndarray.itemsize` will be initialized to the size of `self.elem_type.size_of()`.
/// - `ndarray.shape` and `ndarray.strides` will be allocated on the stack with number of elements being `in_ndims`,
/// all with empty/uninitialized values.
pub fn alloca(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
in_ndims: IntValue<'ctx>,
name: &str,
) -> NpArrayValue<'ctx> {
let ptr = ctx
.builder
.build_alloca(self.get_struct_type(ctx.ctx).as_basic_type_enum(), name)
.unwrap();
// Allocate `in_dims` number of `size_type` on the stack for `shape` and `strides`
let allocated_shape = ctx
.builder
.build_array_alloca(self.size_type.as_basic_type_enum(), in_ndims, "allocated_shape")
.unwrap();
let allocated_strides = ctx
.builder
.build_array_alloca(self.size_type.as_basic_type_enum(), in_ndims, "allocated_strides")
.unwrap();
let value = NpArrayValue { ty: *self, ptr };
value.store_ndims(ctx, in_ndims);
value.store_itemsize(ctx, self.elem_type.size_of().unwrap());
value.store_shape(ctx, allocated_shape);
value.store_strides(ctx, allocated_strides);
return value;
}
pub fn value_from_ptr(
&self,
ctx: &'ctx Context,
in_ndarray_ptr: PointerValue<'ctx>,
) -> NpArrayValue<'ctx> {
if cfg!(debug_assertions) {
// Sanity check on `in_ndarray_ptr`'s type
let in_ndarray_struct_type =
in_ndarray_ptr.get_type().get_element_type().into_struct_type();
// unwrap to check
self.fields(ctx).whole_struct.is_type(in_ndarray_struct_type).unwrap();
}
NpArrayValue { ty: *self, ptr: in_ndarray_ptr }
}
}
#[derive(Debug, Clone, Copy)]
pub struct NpArrayValue<'ctx> {
pub ty: NpArrayType<'ctx>,
pub ptr: PointerValue<'ctx>,
}
impl<'ctx> NpArrayValue<'ctx> {
pub fn load_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let field = self.ty.fields(ctx.ctx).data;
field.load(ctx, self.ptr).into_pointer_value()
}
pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, new_data_ptr: PointerValue<'ctx>) {
let field = self.ty.fields(ctx.ctx).data;
field.store(ctx, self.ptr, new_data_ptr);
}
pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
let field = self.ty.fields(ctx.ctx).ndims;
field.load(ctx, self.ptr).into_int_value()
}
pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, new_ndims: IntValue<'ctx>) {
let field = self.ty.fields(ctx.ctx).ndims;
field.store(ctx, self.ptr, new_ndims);
}
pub fn load_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
let field = self.ty.fields(ctx.ctx).itemsize;
field.load(ctx, self.ptr).into_int_value()
}
pub fn store_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>, new_itemsize: IntValue<'ctx>) {
let field = self.ty.fields(ctx.ctx).itemsize;
field.store(ctx, self.ptr, new_itemsize);
}
pub fn load_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let field = self.ty.fields(ctx.ctx).shape;
field.load(ctx, self.ptr).into_pointer_value()
}
pub fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, new_shape_ptr: PointerValue<'ctx>) {
let field = self.ty.fields(ctx.ctx).shape;
field.store(ctx, self.ptr, new_shape_ptr);
}
pub fn load_strides(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let field = self.ty.fields(ctx.ctx).strides;
field.load(ctx, self.ptr).into_pointer_value()
}
pub fn store_strides(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) {
let field = self.ty.fields(ctx.ctx).strides;
field.store(ctx, self.ptr, value);
}
/// TODO: DOCUMENT ME -- NDIMS WOULD NEVER CHANGE!!!!!
pub fn shape_slice(
&self,
ctx: &CodeGenContext<'ctx, '_>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
// Get the pointer to `shape`
let field = self.ty.fields(ctx.ctx).shape;
let shape = field.load(ctx, self.ptr).into_pointer_value();
// Load `ndims`
let ndims = self.load_ndims(ctx);
TypedArrayLikeAdapter {
adapted: ArraySliceValue(shape, ndims, Some(field.name)),
downcast_fn: Box::new(|_ctx, x| x.into_int_value()),
upcast_fn: Box::new(|_ctx, x| x.as_basic_value_enum()),
}
}
/// TODO: DOCUMENT ME -- NDIMS WOULD NEVER CHANGE!!!!!
pub fn strides_slice(
&self,
ctx: &CodeGenContext<'ctx, '_>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
// Get the pointer to `strides`
let field = self.ty.fields(ctx.ctx).strides;
let strides = field.load(ctx, self.ptr).into_pointer_value();
// Load `ndims`
let ndims = self.load_ndims(ctx);
TypedArrayLikeAdapter {
adapted: ArraySliceValue(strides, ndims, Some(field.name)),
downcast_fn: Box::new(|_ctx, x| x.into_int_value()),
upcast_fn: Box::new(|_ctx, x| x.as_basic_value_enum()),
}
}
}

View File

@ -2202,6 +2202,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
let dst_ndims = deduce_ndims_after_slicing(ndims, ndslices.iter()); let dst_ndims = deduce_ndims_after_slicing(ndims, ndslices.iter());
// Finally, perform the actual subscript logic // Finally, perform the actual subscript logic
// TODO: `call_ndarray_subscript_impl` under the hood deduces `dst_ndims` again. We could save it some time by passing `dst_ndims` - a TODO?
let subndarray = call_ndarray_subscript_impl( let subndarray = call_ndarray_subscript_impl(
generator, generator,
ctx, ctx,

View File

@ -0,0 +1,87 @@
// TODO: Use derppening's abstraction
use std::marker::PhantomData;
use inkwell::{
context::Context,
types::{BasicType, BasicTypeEnum, IntType},
values::BasicValueEnum,
AddressSpace,
};
use crate::codegen::structure::{
CustomStructType, CustomType, Field, FieldCreator, IntType2, Object, PointerType2,
PointingArrayType,
};
#[derive(Debug, Clone, Copy)]
pub struct NpArrayType<'ctx> {
pub size_type: IntType<'ctx>,
pub elem_type: BasicTypeEnum<'ctx>,
}
pub struct NpArrayFields<'ctx> {
pub data: Field<'ctx, PointerType2<'ctx>>,
pub itemsize: Field<'ctx, IntType2<'ctx>>,
pub ndims: Field<'ctx, IntType2<'ctx>>,
pub shape: Field<'ctx, PointingArrayType<'ctx, IntType2<'ctx>>>,
pub strides: Field<'ctx, PointingArrayType<'ctx, IntType2<'ctx>>>,
}
pub type NpArrayValue<'ctx> = Object<'ctx, NpArrayType<'ctx>>;
// impl<'ctx> CustomType<'ctx> for NpArrayType<'ctx> {
// type Value = NpArrayValue<'ctx>;
//
// fn llvm_basic_type_enum(
// &self,
// ctx: &'ctx inkwell::context::Context,
// ) -> inkwell::types::BasicTypeEnum<'ctx> {
// self.llvm_struct_type(ctx).as_basic_type_enum()
// }
//
// fn llvm_field_load(
// &self,
// ctx: &crate::codegen::CodeGenContext<'ctx, '_>,
// field: crate::codegen::structure::FieldInfo,
// struct_ptr: inkwell::values::PointerValue<'ctx>,
// ) -> Self::Value {
// let ok = field.llvm_load(ctx, struct_ptr);
// todo!()
// }
//
// fn llvm_field_store(
// &self,
// ctx: &crate::codegen::CodeGenContext<'ctx, '_>,
// field: crate::codegen::structure::FieldInfo,
// struct_ptr: inkwell::values::PointerValue<'ctx>,
// value: &Self::Value,
// ) {
// todo!()
// }
// }
impl<'ctx> CustomStructType<'ctx> for NpArrayType<'ctx> {
type Fields = NpArrayFields<'ctx>;
fn llvm_struct_name() -> &'static str {
"NDArray"
}
fn add_fields_to(&self, creator: &mut FieldCreator<'ctx>) -> Self::Fields {
let pi8 = creator.ctx.i8_type().ptr_type(AddressSpace::default());
NpArrayFields {
data: creator.add_field("data", PointerType2(pi8)),
itemsize: creator.add_field("itemsize", IntType2(self.size_type)),
ndims: creator.add_field("ndims", IntType2(self.size_type)),
shape: creator.add_field("shape", PointingArrayType::new(IntType2(self.size_type))),
strides: creator.add_field("strides", PointingArrayType::new(IntType2(self.size_type))),
}
}
}
impl<'ctx> NpArrayType<'ctx> {
pub fn new_opaque_elem(ctx: &'ctx Context, size_type: IntType<'ctx>) -> Self {
NpArrayType { elem_type: ctx.i8_type().into(), size_type }
}
}

View File

@ -11,13 +11,15 @@ mod test;
use super::{ use super::{
classes::{ classes::{
check_basic_types_match, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, check_basic_types_match, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue,
NDArrayValue, NpArrayType, NpArrayValue, StructField, StructFields, TypedArrayLikeAdapter, NDArrayValue, StructField, StructFields, TypedArrayLikeAdapter, UntypedArrayLikeAccessor,
UntypedArrayLikeAccessor,
}, },
llvm_intrinsics, CodeGenContext, CodeGenerator, llvm_intrinsics,
structure::CustomStructType,
CodeGenContext, CodeGenerator,
}; };
use crate::codegen::classes::TypedArrayLikeAccessor; use crate::codegen::classes::TypedArrayLikeAccessor;
use crate::codegen::stmt::gen_for_callback_incrementing; use crate::codegen::stmt::gen_for_callback_incrementing;
use classes::{NpArrayType, NpArrayValue};
use crossbeam::channel::IntoIter; use crossbeam::channel::IntoIter;
use inkwell::{ use inkwell::{
attributes::{Attribute, AttributeLoc}, attributes::{Attribute, AttributeLoc},
@ -33,6 +35,7 @@ use inkwell::{
}; };
use itertools::Either; use itertools::Either;
use nac3parser::ast::Expr; use nac3parser::ast::Expr;
pub mod classes;
#[must_use] #[must_use]
pub fn load_irrt(ctx: &Context) -> Module { pub fn load_irrt(ctx: &Context) -> Module {
@ -957,7 +960,7 @@ pub struct UserSlice<'ctx> {
pub step: Option<IntValue<'ctx>>, pub step: Option<IntValue<'ctx>>,
} }
pub struct IrrtUserSliceStructFields<'ctx> { pub struct IrrtUserSliceTypeStructFields<'ctx> {
pub whole_struct: StructFields<'ctx>, pub whole_struct: StructFields<'ctx>,
pub start_defined: StructField<'ctx>, pub start_defined: StructField<'ctx>,
@ -971,10 +974,10 @@ pub struct IrrtUserSliceStructFields<'ctx> {
} }
// TODO: EMPTY STRUCT // TODO: EMPTY STRUCT
struct IrrtUserSlice {} struct IrrtUserSliceType();
impl IrrtUserSlice { impl IrrtUserSliceType {
pub fn fields<'ctx>(ctx: &'ctx Context) -> IrrtUserSliceStructFields<'ctx> { pub fn fields<'ctx>(&self, ctx: &'ctx Context) -> IrrtUserSliceTypeStructFields<'ctx> {
let int8 = ctx.i8_type(); let int8 = ctx.i8_type();
// MUST match the corresponding struct defined in IRRT // MUST match the corresponding struct defined in IRRT
@ -986,7 +989,7 @@ impl IrrtUserSlice {
let step_defined = builder.add_field("step_defined", int8.into()); let step_defined = builder.add_field("step_defined", int8.into());
let step = builder.add_field("step", get_sliceindex_type(ctx).into()); let step = builder.add_field("step", get_sliceindex_type(ctx).into());
IrrtUserSliceStructFields { IrrtUserSliceTypeStructFields {
start_defined, start_defined,
start, start,
stop_defined, stop_defined,
@ -998,11 +1001,12 @@ impl IrrtUserSlice {
} }
pub fn alloca_user_slice<'ctx>( pub fn alloca_user_slice<'ctx>(
&self,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
user_slice: &UserSlice<'ctx>, user_slice: &UserSlice<'ctx>,
) -> PointerValue<'ctx> { ) -> PointerValue<'ctx> {
// Derive the struct_type // Derive the struct_type
let fields = Self::fields(ctx.ctx); let fields = self.fields(ctx.ctx);
let struct_type = fields.whole_struct.get_struct_type(ctx.ctx); let struct_type = fields.whole_struct.get_struct_type(ctx.ctx);
// ...and then allocate for a real `UserSlice` in LLVM // ...and then allocate for a real `UserSlice` in LLVM
@ -1070,23 +1074,22 @@ where
} }
// TODO: Empty struct // TODO: Empty struct
pub struct IrrtNDSlice {} pub struct IrrtNDSliceType();
pub struct IrrtNDSliceTypeStructFields<'ctx> {
pub struct IrrtNDSliceStructFields<'ctx> {
pub whole_struct: StructFields<'ctx>, pub whole_struct: StructFields<'ctx>,
pub type_: StructField<'ctx>, pub type_: StructField<'ctx>,
pub slice: StructField<'ctx>, pub slice: StructField<'ctx>,
} }
impl IrrtNDSlice { impl IrrtNDSliceType {
pub fn fields<'ctx>(ctx: &'ctx Context) -> IrrtNDSliceStructFields<'ctx> { pub fn fields<'ctx>(ctx: &'ctx Context) -> IrrtNDSliceTypeStructFields<'ctx> {
let mut builder = StructFieldsBuilder::start("NDSlice"); let mut builder = StructFieldsBuilder::start("NDSlice");
// MUST match the corresponding struct defined in IRRT // MUST match the corresponding struct defined in IRRT
let type_ = builder.add_field("type", get_ndslicetype_constant_type(ctx).into()); let type_ = builder.add_field("type", get_ndslicetype_constant_type(ctx).into());
let slice = builder.add_field("slice", get_opaque_uint8_ptr_type(ctx).into()); let slice = builder.add_field("slice", get_opaque_uint8_ptr_type(ctx).into());
IrrtNDSliceStructFields { type_, slice, whole_struct: builder.end() } IrrtNDSliceTypeStructFields { type_, slice, whole_struct: builder.end() }
} }
pub fn alloca_ndslices<'ctx>( pub fn alloca_ndslices<'ctx>(
@ -1124,7 +1127,7 @@ impl IrrtNDSlice {
} }
NDSlice::Slice(user_slice) => { NDSlice::Slice(user_slice) => {
// Allocate the user_slice // Allocate the user_slice
let slice_ptr = IrrtUserSlice::alloca_user_slice(ctx, user_slice); let slice_ptr = IrrtUserSliceType().alloca_user_slice(ctx, user_slice);
let type_ = 1; // const NDSliceType INPUT_SLICE_TYPE_SLICE = 1; let type_ = 1; // const NDSliceType INPUT_SLICE_TYPE_SLICE = 1;
(type_, slice_ptr) (type_, slice_ptr)
@ -1156,6 +1159,84 @@ impl IrrtNDSlice {
} }
} }
struct IrrtPrinterType();
struct IrrtPrinterTypeStructFields<'ctx> {
whole_struct: StructFields<'ctx>,
string_base_ptr: StructField<'ctx>,
max_length: StructField<'ctx>,
length: StructField<'ctx>,
}
impl IrrtPrinterType {
pub fn fields<'ctx>(&self, ctx: &'ctx Context) -> IrrtPrinterTypeStructFields<'ctx> {
let mut builder = StructFieldsBuilder::start("Printer");
let string_base_ptr = builder
.add_field("string_base_ptr", ctx.i8_type().ptr_type(AddressSpace::default()).into());
let max_length = builder.add_field("max_length", ctx.i32_type().into());
let length = builder.add_field("length", ctx.i32_type().into());
IrrtPrinterTypeStructFields {
string_base_ptr,
max_length,
length,
whole_struct: builder.end(),
}
}
}
struct IrrtPrinterValue<'ctx> {
ty: IrrtPrinterType,
ptr: PointerValue<'ctx>,
}
impl<'ctx> IrrtPrinterValue<'ctx> {
pub fn hl(&self) {}
}
struct IrrtErrorContextType();
struct IrrtErrorContextTypeStructFields<'ctx> {
whole_struct: StructFields<'ctx>,
error: StructField<'ctx>,
}
impl IrrtErrorContextType {
pub fn fields<'ctx>(&self, ctx: &'ctx Context) -> IrrtErrorContextTypeStructFields<'ctx> {
let mut builder = StructFieldsBuilder::start("ErrorContext");
let error = builder.add_field(
"error",
IrrtPrinterType().fields(ctx).whole_struct.get_struct_type(ctx).into(),
);
IrrtErrorContextTypeStructFields { error, whole_struct: builder.end() }
}
pub fn alloca<'ctx>(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let fields = self.fields(ctx.ctx);
let struct_type = fields.whole_struct.get_struct_type(ctx.ctx);
ctx.builder.build_alloca(struct_type, "error_context").unwrap()
}
pub fn load_error<'ctx>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
struct_ptr: PointerValue<'ctx>,
) -> IrrtPrinterValue<'ctx> {
let error = self.fields(ctx.ctx).error;
IrrtPrinterValue {
ty: IrrtPrinterType(),
ptr: error.load(ctx, struct_ptr).into_pointer_value(),
}
}
}
// struct IrrtErrorContextValue<'ctx> {
// ty: IrrtErrorContextValue
// }
fn get_size_variant<'ctx>(ty: IntType<'ctx>) -> SizeVariant { fn get_size_variant<'ctx>(ty: IntType<'ctx>) -> SizeVariant {
match ty.get_bit_width() { match ty.get_bit_width() {
32 => SizeVariant::Bits32, 32 => SizeVariant::Bits32,
@ -1198,8 +1279,7 @@ pub fn get_irrt_ndarray_ptr_type<'ctx>(
let i8_type = ctx.i8_type(); let i8_type = ctx.i8_type();
let ndarray_ty = NpArrayType { size_type, elem_type: i8_type.as_basic_type_enum() }; let ndarray_ty = NpArrayType { size_type, elem_type: i8_type.as_basic_type_enum() };
let struct_ty = ndarray_ty.get_struct_type(ctx); ndarray_ty.llvm_struct_type(ctx).ptr_type(AddressSpace::default())
struct_ty.ptr_type(AddressSpace::default())
} }
pub fn get_opaque_uint8_ptr_type<'ctx>(ctx: &'ctx Context) -> PointerType<'ctx> { pub fn get_opaque_uint8_ptr_type<'ctx>(ctx: &'ctx Context) -> PointerType<'ctx> {
@ -1317,7 +1397,7 @@ pub fn call_nac3_ndarray_deduce_ndims_after_slicing<'ctx>(
&[ &[
size_type.into(), // SizeT ndims size_type.into(), // SizeT ndims
size_type.into(), // SizeT num_slices size_type.into(), // SizeT num_slices
IrrtNDSlice::fields(ctx.ctx) IrrtNDSliceType::fields(ctx.ctx)
.whole_struct .whole_struct
.get_struct_type(ctx.ctx) .get_struct_type(ctx.ctx)
.ptr_type(AddressSpace::default()) .ptr_type(AddressSpace::default())
@ -1360,7 +1440,7 @@ pub fn call_nac3_ndarray_subscript<'ctx>(
&[ &[
get_irrt_ndarray_ptr_type(ctx.ctx, size_type).into(), // NDArray<SizeT>* ndarray get_irrt_ndarray_ptr_type(ctx.ctx, size_type).into(), // NDArray<SizeT>* ndarray
size_type.into(), // SizeT num_slices size_type.into(), // SizeT num_slices
IrrtNDSlice::fields(ctx.ctx) IrrtNDSliceType::fields(ctx.ctx)
.whole_struct .whole_struct
.get_struct_type(ctx.ctx) .get_struct_type(ctx.ctx)
.ptr_type(AddressSpace::default()) .ptr_type(AddressSpace::default())
@ -1393,19 +1473,14 @@ pub fn call_nac3_len<'ctx>(
let size_type = ndarray.ty.size_type; let size_type = ndarray.ty.size_type;
// Get the IRRT function // Get the IRRT function
let function = get_size_type_dependent_function( let function = get_size_type_dependent_function(ctx, size_type, "__nac3_ndarray_len", || {
ctx,
size_type,
"__nac3_ndarray_len",
|| {
get_sliceindex_type(ctx.ctx).fn_type( get_sliceindex_type(ctx.ctx).fn_type(
&[ &[
get_irrt_ndarray_ptr_type(ctx.ctx, size_type).into(), // NDArray<SizeT> *ndarray get_irrt_ndarray_ptr_type(ctx.ctx, size_type).into(), // NDArray<SizeT> *ndarray
], ],
false, false,
) )
}, });
);
// Call the IRRT function // Call the IRRT function
ctx.builder ctx.builder

View File

@ -44,6 +44,7 @@ pub mod irrt;
pub mod llvm_intrinsics; pub mod llvm_intrinsics;
pub mod numpy; pub mod numpy;
pub mod stmt; pub mod stmt;
pub mod structure;
#[cfg(test)] #[cfg(test)]
mod test; mod test;

View File

@ -38,7 +38,7 @@ use super::{
irrt::{ irrt::{
call_nac3_ndarray_deduce_ndims_after_slicing, call_nac3_ndarray_set_strides_by_shape, call_nac3_ndarray_deduce_ndims_after_slicing, call_nac3_ndarray_set_strides_by_shape,
call_nac3_ndarray_size, call_nac3_ndarray_subscript, get_irrt_ndarray_ptr_type, call_nac3_ndarray_size, call_nac3_ndarray_subscript, get_irrt_ndarray_ptr_type,
get_opaque_uint8_ptr_type, IrrtNDSlice, NDSlice, get_opaque_uint8_ptr_type, IrrtNDSliceType, NDSlice,
}, },
stmt::gen_return, stmt::gen_return,
}; };
@ -2364,7 +2364,7 @@ where
let num_slices = size_type.const_int(ndslices.len() as u64, false); let num_slices = size_type.const_int(ndslices.len() as u64, false);
// Prepare the argument `slices` // Prepare the argument `slices`
let ndslices_ptr = IrrtNDSlice::alloca_ndslices(ctx, ndslices); let ndslices_ptr = IrrtNDSliceType::alloca_ndslices(ctx, ndslices);
// Get `dst_ndims` // Get `dst_ndims`
let dst_ndims = call_nac3_ndarray_deduce_ndims_after_slicing( let dst_ndims = call_nac3_ndarray_deduce_ndims_after_slicing(

View File

@ -0,0 +1,318 @@
use std::marker::PhantomData;
use inkwell::{
context::Context,
types::{BasicType, BasicTypeEnum, IntType, PointerType, StructType},
values::{BasicValue, BasicValueEnum, IntValue, PointerValue},
AddressSpace,
};
use super::CodeGenContext;
#[derive(Debug, Clone, Copy)]
pub struct FieldInfo {
gep_index: u32,
name: &'static str,
}
impl FieldInfo {
pub fn llvm_gep<'ctx>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
struct_ptr: PointerValue<'ctx>,
) -> PointerValue<'ctx> {
let index_type = ctx.ctx.i32_type(); // TODO: I think I'm not supposed to *just* use i32 for GEP like that
unsafe {
ctx.builder
.build_in_bounds_gep(
struct_ptr,
&[index_type.const_zero(), index_type.const_int(self.gep_index as u64, false)],
self.name,
)
.unwrap()
}
}
pub fn llvm_load<'ctx>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
struct_ptr: PointerValue<'ctx>,
) -> BasicValueEnum<'ctx> {
// We will use `self.name` as the LLVM label for debugging purposes
ctx.builder.build_load(self.llvm_gep(ctx, struct_ptr), self.name).unwrap()
}
pub fn llvm_store<'ctx>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
struct_ptr: PointerValue<'ctx>,
value: BasicValueEnum<'ctx>,
) {
ctx.builder.build_store(self.llvm_gep(ctx, struct_ptr), value).unwrap();
}
}
pub struct Object<'ctx, T> {
pub ty: T,
pub ptr: PointerValue<'ctx>,
}
pub struct Field<'ctx, T: CustomType<'ctx>> {
pub info: FieldInfo,
pub ty: T,
_phantom: PhantomData<&'ctx ()>,
}
pub struct FieldCreator<'ctx> {
pub ctx: &'ctx Context,
struct_name: &'ctx str,
gep_index_counter: u32,
fields: Vec<(FieldInfo, BasicTypeEnum<'ctx>)>,
}
impl<'ctx> FieldCreator<'ctx> {
pub fn new(ctx: &'ctx Context, struct_name: &'ctx str) -> Self {
FieldCreator { ctx, struct_name, gep_index_counter: 0, fields: Vec::new() }
}
fn next_gep_index(&mut self) -> u32 {
let index = self.gep_index_counter;
self.gep_index_counter += 1;
index
}
fn get_struct_field_types(&self) -> Vec<BasicTypeEnum<'ctx>> {
self.fields.iter().map(|x| x.1.clone()).collect()
}
pub fn add_field<T: CustomType<'ctx>>(&mut self, name: &'static str, ty: T) -> Field<'ctx, T> {
let gep_index = self.next_gep_index();
let field_type = ty.llvm_basic_type_enum(self.ctx);
let field_info = FieldInfo { gep_index, name };
let field = Field { info: field_info, ty, _phantom: PhantomData };
self.fields.push((field_info.clone(), field_type));
field
}
fn num_fields(&self) -> u32 {
self.fields.len() as u32 // casted to u32 because that is what inkwell returns
}
}
pub trait CustomType<'ctx>: Clone {
type Value;
fn llvm_basic_type_enum(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx>;
fn llvm_field_load(
&self,
ctx: &CodeGenContext<'ctx, '_>,
field: FieldInfo,
struct_ptr: PointerValue<'ctx>,
) -> Self::Value;
fn llvm_field_store(
&self,
ctx: &CodeGenContext<'ctx, '_>,
field: FieldInfo,
struct_ptr: PointerValue<'ctx>,
value: &Self::Value,
);
}
#[derive(Debug, Clone, Copy)]
pub struct IntType2<'ctx>(pub IntType<'ctx>);
impl<'ctx> CustomType<'ctx> for IntType2<'ctx> {
type Value = IntValue<'ctx>;
fn llvm_basic_type_enum(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> {
self.0.as_basic_type_enum()
}
fn llvm_field_load(
&self,
ctx: &CodeGenContext<'ctx, '_>,
field: FieldInfo,
struct_ptr: PointerValue<'ctx>,
) -> Self::Value {
let int_value = field.llvm_load(ctx, struct_ptr).into_int_value();
assert_eq!(int_value.get_type().get_bit_width(), self.0.get_bit_width());
int_value
}
fn llvm_field_store(
&self,
ctx: &CodeGenContext<'ctx, '_>,
field: FieldInfo,
struct_ptr: PointerValue<'ctx>,
int_value: &Self::Value,
) {
assert_eq!(int_value.get_type().get_bit_width(), self.0.get_bit_width());
field.llvm_store(ctx, struct_ptr, int_value.as_basic_value_enum());
}
}
#[derive(Debug, Clone, Copy)]
pub struct PointerType2<'ctx>(pub PointerType<'ctx>);
impl<'ctx> CustomType<'ctx> for PointerType2<'ctx> {
type Value = PointerValue<'ctx>;
fn llvm_basic_type_enum(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> {
self.0.as_basic_type_enum()
}
fn llvm_field_load(
&self,
ctx: &CodeGenContext<'ctx, '_>,
field: FieldInfo,
struct_ptr: PointerValue<'ctx>,
) -> Self::Value {
field.llvm_load(ctx, struct_ptr).into_pointer_value()
}
fn llvm_field_store(
&self,
ctx: &CodeGenContext<'ctx, '_>,
field: FieldInfo,
struct_ptr: PointerValue<'ctx>,
pointer_value: &Self::Value,
) {
field.llvm_store(ctx, struct_ptr, pointer_value.as_basic_value_enum());
}
}
#[derive(Debug, Clone, Copy)]
pub struct PointingArrayType<'ctx, ElementType: CustomType<'ctx>> {
pub element_type: ElementType,
_phantom: PhantomData<&'ctx ()>,
}
impl<'ctx, ElementType: CustomType<'ctx>> PointingArrayType<'ctx, ElementType> {
pub fn new(element_type: ElementType) -> Self {
PointingArrayType { element_type, _phantom: PhantomData }
}
}
impl<'ctx, Element: CustomType<'ctx>> CustomType<'ctx> for PointingArrayType<'ctx, Element> {
type Value = Object<'ctx, Self>;
fn llvm_basic_type_enum(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> {
// Element*
self.element_type
.llvm_basic_type_enum(ctx)
.ptr_type(AddressSpace::default())
.as_basic_type_enum()
}
fn llvm_field_load(
&self,
ctx: &CodeGenContext<'ctx, '_>,
field: FieldInfo,
struct_ptr: PointerValue<'ctx>,
) -> Self::Value {
// Remember that it is just a pointer
Object { ty: self.clone(), ptr: field.llvm_load(ctx, struct_ptr).into_pointer_value() }
}
fn llvm_field_store(
&self,
ctx: &CodeGenContext<'ctx, '_>,
field: FieldInfo,
struct_ptr: PointerValue<'ctx>,
value: &Self::Value,
) {
// Remember that it is just a pointer
todo!()
}
}
pub fn check_basic_types_match<'ctx, A, B>(expected: A, got: B) -> Result<(), String>
where
A: BasicType<'ctx>,
B: BasicType<'ctx>,
{
let expected = expected.as_basic_type_enum();
let got = got.as_basic_type_enum();
// Put those logic into here,
// otherwise there is always a fallback reporting on any kind of mismatch
match (expected, got) {
(BasicTypeEnum::IntType(expected), BasicTypeEnum::IntType(got)) => {
if expected.get_bit_width() != got.get_bit_width() {
return Err(format!(
"Expected IntType ({expected}-bit(s)), got IntType ({got}-bit(s))"
));
}
}
(expected, got) => {
if expected != got {
return Err(format!("Expected {expected}, got {got}"));
}
}
}
Ok(())
}
pub trait CustomStructType<'ctx> {
type Fields;
fn llvm_struct_name() -> &'static str;
fn add_fields_to(&self, creator: &mut FieldCreator<'ctx>) -> Self::Fields;
fn fields(&self, ctx: &'ctx Context) -> Self::Fields {
let mut creator = FieldCreator::new(ctx, Self::llvm_struct_name());
let fields = self.add_fields_to(&mut creator);
fields
}
fn llvm_struct_type(&self, ctx: &'ctx Context) -> StructType<'ctx> {
let mut creator = FieldCreator::new(ctx, Self::llvm_struct_name());
self.add_fields_to(&mut creator);
ctx.struct_type(&creator.get_struct_field_types(), false)
}
fn check_struct_type(
&self,
ctx: &'ctx Context,
scrutinee: StructType<'ctx>,
) -> Result<(), String> {
let mut creator = FieldCreator::new(ctx, Self::llvm_struct_name());
self.add_fields_to(&mut creator);
// Check scrutinee's number of struct fields
let expected_field_count = creator.num_fields();
let got_field_count = scrutinee.count_fields();
if got_field_count != expected_field_count {
return Err(format!(
"Expected {expected_count} field(s) in `{struct_name}` type, got {got_count}",
struct_name = Self::llvm_struct_name(),
expected_count = expected_field_count,
got_count = got_field_count,
));
}
// Check the scrutinee's field types
for (field_info, expected_field_ty) in creator.fields {
let got_field_ty = scrutinee.get_field_type_at_index(field_info.gep_index).unwrap();
if let Err(field_err) = check_basic_types_match(expected_field_ty, got_field_ty) {
return Err(format!(
"Field GEP index {gep_index} does not match the expected type of ({struct_name}::{field_name}): {field_err}",
gep_index = field_info.gep_index,
struct_name = Self::llvm_struct_name(),
field_name = field_info.name,
));
}
}
// Done
Ok(())
}
}

View File

@ -1,6 +1,7 @@
use std::iter::once; use std::iter::once;
use crate::{codegen::classes::NpArrayType, util::SizeVariant}; use crate::util::SizeVariant;
use classes::NpArrayType;
use helper::{debug_assert_prim_is_allowed, make_exception_fields, PrimDefDetails}; use helper::{debug_assert_prim_is_allowed, make_exception_fields, PrimDefDetails};
use indexmap::IndexMap; use indexmap::IndexMap;
use inkwell::{ use inkwell::{
@ -1466,10 +1467,13 @@ impl<'a> BuiltinBuilder<'a> {
let ndarray_ptr = arg.into_pointer_value(); // It has to be an ndarray let ndarray_ptr = arg.into_pointer_value(); // It has to be an ndarray
let size_type = generator.get_size_type(ctx.ctx); let size_type = generator.get_size_type(ctx.ctx);
let ndarray_ty = NpArrayType::new_opaque_elem(ctx, size_type); // We don't need to care about the element type - we only want the shape let ndarray_ty = NpArrayType::new_opaque_elem(ctx.ctx, size_type); // We don't need to care about the element type - we only want the shape
let ndarray = ndarray_ty.value_from_ptr(ctx.ctx, ndarray_ptr); let ndarray = ndarray_ty.value_from_ptr(ctx.ctx, ndarray_ptr);
Some(call_nac3_len(ctx, ndarray).as_basic_value_enum()) let result = call_nac3_len(ctx, ndarray).as_basic_value_enum();
Some(result)
// Some(.as_basic_value_enum())
// let llvm_i32 = ctx.ctx.i32_type(); // let llvm_i32 = ctx.ctx.i32_type();
// let llvm_usize = generator.get_size_type(ctx.ctx); // let llvm_usize = generator.get_size_type(ctx.ctx);