forked from M-Labs/nac3
WIP: core: save progress
This commit is contained in:
parent
ab7ff2ae9d
commit
3f4ee433f1
|
@ -0,0 +1,31 @@
|
|||
#pragma once
|
||||
|
||||
#include "irrt_printer.hpp"
|
||||
|
||||
namespace {
|
||||
#define MAX_ERROR_NAME_LEN 32
|
||||
|
||||
// TODO: right now just to report some messages for now
|
||||
struct ErrorContext {
|
||||
Printer error;
|
||||
// TODO: add error_class_name??
|
||||
|
||||
void initialize(char* string_base_ptr, uint32_t max_length) {
|
||||
error.initialize(string_base_ptr, max_length);
|
||||
}
|
||||
|
||||
bool has_error() {
|
||||
return error.length > 0;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
void __nac3_error_context_init(ErrorContext* ctx, char* string_base_ptr, uint32_t max_length) {
|
||||
ctx->initialize(string_base_ptr, max_length);
|
||||
}
|
||||
|
||||
uint8_t __nac3_error_context_has_error(ErrorContext* ctx) {
|
||||
return (uint8_t) ctx->has_error();
|
||||
}
|
||||
}
|
|
@ -1,10 +1,12 @@
|
|||
#pragma once
|
||||
|
||||
#include "irrt_utils.hpp"
|
||||
#include "irrt_typedefs.hpp"
|
||||
#include "irrt_basic.hpp"
|
||||
#include "irrt_slice.hpp"
|
||||
#include "irrt_error_context.hpp"
|
||||
#include "irrt_numpy_ndarray.hpp"
|
||||
#include "irrt_printer.hpp"
|
||||
#include "irrt_slice.hpp"
|
||||
#include "irrt_typedefs.hpp"
|
||||
#include "irrt_utils.hpp"
|
||||
|
||||
/*
|
||||
All IRRT implementations.
|
||||
|
|
|
@ -0,0 +1,82 @@
|
|||
#pragma once
|
||||
|
||||
#include "irrt_typedefs.hpp"
|
||||
|
||||
// TODO: obviously implementing printf from scratch is bad,
|
||||
// is there a header only, no-cstdlib library for this?
|
||||
|
||||
namespace {
|
||||
struct Printer {
|
||||
char* string_base_ptr;
|
||||
uint32_t max_length;
|
||||
uint32_t length; // NOTE: this could be incremented past max_length, which indicates
|
||||
|
||||
void initialize(char *string_base_ptr, uint32_t max_length) {
|
||||
this->string_base_ptr = string_base_ptr;
|
||||
this->max_length = max_length;
|
||||
this->length = 0;
|
||||
}
|
||||
|
||||
void put_space() {
|
||||
put_char(' ');
|
||||
}
|
||||
|
||||
void put_char(char ch) {
|
||||
push_char(ch);
|
||||
}
|
||||
|
||||
void put_string(const char* string) {
|
||||
// TODO: optimize?
|
||||
while (*string != '\0') {
|
||||
push_char(*string);
|
||||
string++; // Move to next char
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void put_int(T value) {
|
||||
// NOTE: Try not to use recursion to print the digits
|
||||
|
||||
// value == 0 is a special case
|
||||
if (value == 0) {
|
||||
push_char('0');
|
||||
} else {
|
||||
// Add a '-' if the value is negative
|
||||
if (value < 0) {
|
||||
push_char('-');
|
||||
value = -value; // Negate then continue to print the digits
|
||||
}
|
||||
|
||||
// TODO: Recursion is a bad idea on embedded systems?
|
||||
uint32_t num_digits = int_log_floor(value, 10) + 1;
|
||||
put_int_helper(num_digits, value);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: implement put_float() and more would be useful
|
||||
private:
|
||||
void push_char(char ch) {
|
||||
if (length < max_length) {
|
||||
string_base_ptr[length] = ch;
|
||||
}
|
||||
|
||||
// NOTE: this could increment past max_length,
|
||||
// to indicate the true length of the message even if it gets cut off
|
||||
length++;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void put_int_helper(uint32_t num_digits, T value) {
|
||||
// Print the digits recursively
|
||||
__builtin_assume(0 <= value);
|
||||
|
||||
if (num_digits > 0) {
|
||||
put_int_helper(num_digits - 1, value / 10);
|
||||
|
||||
uint32_t digit = value % 10;
|
||||
char digit_char = '0' + (char) digit;
|
||||
put_char(digit_char);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
|
@ -675,6 +675,17 @@ void test_ndarray_broadcast_1() {
|
|||
assert_values_match("dst_ndarray[1, 2, 3]", "%f", 49.9, *((double*) dst_ndarray.get_pelement_by_indices((int32_t[]) {1, 2, 3})));
|
||||
}
|
||||
|
||||
void test_printer() {
|
||||
const uint32_t buffer_len = 256;
|
||||
char buffer[buffer_len];
|
||||
Printer printer = {
|
||||
.string_base_ptr = buffer,
|
||||
.max_length = buffer_len,
|
||||
.length = 0
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
int main() {
|
||||
test_calc_size_from_shape_normal();
|
||||
test_calc_size_from_shape_has_zero();
|
||||
|
@ -691,5 +702,6 @@ int main() {
|
|||
test_ndslice_3();
|
||||
test_can_broadcast_shape();
|
||||
test_ndarray_broadcast_1();
|
||||
test_printer();
|
||||
return 0;
|
||||
}
|
|
@ -21,6 +21,43 @@ namespace {
|
|||
return true;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
uint32_t int_log_floor(T value, T base) {
|
||||
uint32_t result = 0;
|
||||
while (value < base) {
|
||||
result++;
|
||||
value /= base;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
bool string_is_empty(const char *str) {
|
||||
return str[0] == '\0';
|
||||
}
|
||||
|
||||
// TODO: DOCUMENT ME!!!!!
|
||||
// returns false if `src_str` could not be fully copied over to `dst_str`
|
||||
bool string_copy(uint32_t dst_max_size, char* dst_str, const char* src_str) {
|
||||
// This function guarantess that `dst_str` will be null-terminated,
|
||||
|
||||
for (uint32_t i = 0; i < dst_max_size; i++) {
|
||||
bool is_last = i + 1 == dst_max_size;
|
||||
if (is_last && src_str[i] != '\0') {
|
||||
dst_str[i] = '\0';
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src_str[i] == '\0') {
|
||||
dst_str[i] = '\0';
|
||||
return true;
|
||||
}
|
||||
|
||||
dst_str[i] = src_str[i];
|
||||
}
|
||||
|
||||
__builtin_unreachable();
|
||||
}
|
||||
|
||||
void irrt_panic() {
|
||||
// Crash the program for now.
|
||||
// TODO: Don't crash the program
|
||||
|
|
|
@ -1768,357 +1768,163 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeMutator<'ctx,
|
|||
{
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct StructField<'ctx> {
|
||||
/// The GEP index of this struct field.
|
||||
pub gep_index: u32,
|
||||
/// Name of this struct field.
|
||||
///
|
||||
/// Used for generating names.
|
||||
pub name: &'static str,
|
||||
/// The type of this struct field.
|
||||
pub ty: BasicTypeEnum<'ctx>,
|
||||
}
|
||||
// #[derive(Debug, Clone, Copy)]
|
||||
// pub struct StructField<'ctx> {
|
||||
// /// The GEP index of this struct field.
|
||||
// pub gep_index: u32,
|
||||
// /// Name of this struct field.
|
||||
// ///
|
||||
// /// Used for generating names.
|
||||
// pub name: &'static str,
|
||||
// /// The type of this struct field.
|
||||
// pub ty: BasicTypeEnum<'ctx>,
|
||||
// }
|
||||
//
|
||||
// pub struct StructFields<'ctx> {
|
||||
// /// Name of the struct.
|
||||
// ///
|
||||
// /// Used for generating names.
|
||||
// pub name: &'static str,
|
||||
//
|
||||
// /// All the [`StructField`]s of this struct.
|
||||
// ///
|
||||
// /// **NOTE:** The index position of a [`StructField`]
|
||||
// /// matches the element's [`StructField::index`].
|
||||
// pub fields: Vec<StructField<'ctx>>,
|
||||
// }
|
||||
//
|
||||
// pub struct StructFieldsBuilder<'ctx> {
|
||||
// gep_index_counter: u32,
|
||||
// /// Name of the struct to be built.
|
||||
// name: &'static str,
|
||||
// fields: Vec<StructField<'ctx>>,
|
||||
// }
|
||||
//
|
||||
// impl<'ctx> StructField<'ctx> {
|
||||
// /// TODO: DOCUMENT ME
|
||||
// pub fn gep(
|
||||
// &self,
|
||||
// ctx: &CodeGenContext<'ctx, '_>,
|
||||
// struct_ptr: PointerValue<'ctx>,
|
||||
// ) -> PointerValue<'ctx> {
|
||||
// let index_type = ctx.ctx.i32_type(); // TODO: I think I'm not supposed to use i32 for GEP like that
|
||||
// unsafe {
|
||||
// ctx.builder
|
||||
// .build_in_bounds_gep(
|
||||
// struct_ptr,
|
||||
// &[index_type.const_zero(), index_type.const_int(self.gep_index as u64, false)],
|
||||
// self.name,
|
||||
// )
|
||||
// .unwrap()
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// /// TODO: DOCUMENT ME
|
||||
// pub fn load(
|
||||
// &self,
|
||||
// ctx: &CodeGenContext<'ctx, '_>,
|
||||
// struct_ptr: PointerValue<'ctx>,
|
||||
// ) -> BasicValueEnum<'ctx> {
|
||||
// ctx.builder.build_load(self.gep(ctx, struct_ptr), self.name).unwrap()
|
||||
// }
|
||||
//
|
||||
// /// TODO: DOCUMENT ME
|
||||
// pub fn store<V>(&self, ctx: &CodeGenContext<'ctx, '_>, struct_ptr: PointerValue<'ctx>, value: V)
|
||||
// where
|
||||
// V: BasicValue<'ctx>,
|
||||
// {
|
||||
// ctx.builder.build_store(self.gep(ctx, struct_ptr), value).unwrap();
|
||||
// }
|
||||
// }
|
||||
|
||||
pub struct StructFields<'ctx> {
|
||||
/// Name of the struct.
|
||||
///
|
||||
/// Used for generating names.
|
||||
pub name: &'static str,
|
||||
// type IsInstanceError = String;
|
||||
// type IsInstanceResult = Result<(), IsInstanceError>;
|
||||
|
||||
/// All the [`StructField`]s of this struct.
|
||||
///
|
||||
/// **NOTE:** The index position of a [`StructField`]
|
||||
/// matches the element's [`StructField::index`].
|
||||
pub fields: Vec<StructField<'ctx>>,
|
||||
}
|
||||
// pub fn check_basic_types_match<'ctx, A, B>(expected: A, got: B) -> IsInstanceResult
|
||||
// where
|
||||
// A: BasicType<'ctx>,
|
||||
// B: BasicType<'ctx>,
|
||||
// {
|
||||
// let expected = expected.as_basic_type_enum();
|
||||
// let got = got.as_basic_type_enum();
|
||||
|
||||
pub struct StructFieldsBuilder<'ctx> {
|
||||
gep_index_counter: u32,
|
||||
/// Name of the struct to be built.
|
||||
name: &'static str,
|
||||
fields: Vec<StructField<'ctx>>,
|
||||
}
|
||||
// // Put those logic into here,
|
||||
// // otherwise there is always a fallback reporting on any kind of mismatch
|
||||
// match (expected, got) {
|
||||
// (BasicTypeEnum::IntType(expected), BasicTypeEnum::IntType(got)) => {
|
||||
// if expected.get_bit_width() != got.get_bit_width() {
|
||||
// return Err(format!(
|
||||
// "Expected IntType ({expected}-bit(s)), got IntType ({got}-bit(s))"
|
||||
// ));
|
||||
// }
|
||||
// }
|
||||
// (expected, got) => {
|
||||
// if expected != got {
|
||||
// return Err(format!("Expected {expected}, got {got}"));
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// Ok(())
|
||||
// }
|
||||
|
||||
impl<'ctx> StructField<'ctx> {
|
||||
/// TODO: DOCUMENT ME
|
||||
pub fn gep(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
struct_ptr: PointerValue<'ctx>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let index_type = ctx.ctx.i32_type(); // TODO: I think I'm not supposed to use i32 for GEP like that
|
||||
unsafe {
|
||||
ctx.builder
|
||||
.build_in_bounds_gep(
|
||||
struct_ptr,
|
||||
&[index_type.const_zero(), index_type.const_int(self.gep_index as u64, false)],
|
||||
self.name,
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// TODO: DOCUMENT ME
|
||||
pub fn load(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
struct_ptr: PointerValue<'ctx>,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
ctx.builder.build_load(self.gep(ctx, struct_ptr), self.name).unwrap()
|
||||
}
|
||||
|
||||
/// TODO: DOCUMENT ME
|
||||
pub fn store<V>(&self, ctx: &CodeGenContext<'ctx, '_>, struct_ptr: PointerValue<'ctx>, value: V)
|
||||
where
|
||||
V: BasicValue<'ctx>,
|
||||
{
|
||||
ctx.builder.build_store(self.gep(ctx, struct_ptr), value).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
type IsInstanceError = String;
|
||||
type IsInstanceResult = Result<(), IsInstanceError>;
|
||||
|
||||
pub fn check_basic_types_match<'ctx, A, B>(expected: A, got: B) -> IsInstanceResult
|
||||
where
|
||||
A: BasicType<'ctx>,
|
||||
B: BasicType<'ctx>,
|
||||
{
|
||||
let expected = expected.as_basic_type_enum();
|
||||
let got = got.as_basic_type_enum();
|
||||
|
||||
// Put those logic into here,
|
||||
// otherwise there is always a fallback reporting on any kind of mismatch
|
||||
match (expected, got) {
|
||||
(BasicTypeEnum::IntType(expected), BasicTypeEnum::IntType(got)) => {
|
||||
if expected.get_bit_width() != got.get_bit_width() {
|
||||
return Err(format!(
|
||||
"Expected IntType ({expected}-bit(s)), got IntType ({got}-bit(s))"
|
||||
));
|
||||
}
|
||||
}
|
||||
(expected, got) => {
|
||||
if expected != got {
|
||||
return Err(format!("Expected {expected}, got {got}"));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
impl<'ctx> StructFields<'ctx> {
|
||||
pub fn num_fields(&self) -> u32 {
|
||||
self.fields.len() as u32
|
||||
}
|
||||
|
||||
pub fn get_struct_type(&self, ctx: &'ctx Context) -> StructType<'ctx> {
|
||||
let llvm_fields = self.fields.iter().map(|field| field.ty).collect_vec();
|
||||
ctx.struct_type(llvm_fields.as_slice(), false)
|
||||
}
|
||||
|
||||
pub fn is_type(&self, scrutinee: StructType<'ctx>) -> IsInstanceResult {
|
||||
// Check scrutinee's number of struct fields
|
||||
if scrutinee.count_fields() != self.num_fields() {
|
||||
return Err(format!(
|
||||
"Expected {expected_count} field(s) in `{struct_name}` type, got {got_count}",
|
||||
struct_name = self.name,
|
||||
expected_count = self.num_fields(),
|
||||
got_count = scrutinee.count_fields(),
|
||||
));
|
||||
}
|
||||
|
||||
// Check the scrutinee's field types
|
||||
for field in self.fields.iter() {
|
||||
let expected_field_ty = field.ty;
|
||||
let got_field_ty = scrutinee.get_field_type_at_index(field.gep_index).unwrap();
|
||||
|
||||
if let Err(field_err) = check_basic_types_match(expected_field_ty, got_field_ty) {
|
||||
return Err(format!(
|
||||
"Field GEP index {gep_index} does not match the expected type of ({struct_name}::{field_name}): {field_err}",
|
||||
gep_index = field.gep_index,
|
||||
struct_name = self.name,
|
||||
field_name = field.name,
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// Done
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> StructFieldsBuilder<'ctx> {
|
||||
pub fn start(name: &'static str) -> Self {
|
||||
StructFieldsBuilder { gep_index_counter: 0, name, fields: Vec::new() }
|
||||
}
|
||||
|
||||
pub fn add_field(&mut self, name: &'static str, ty: BasicTypeEnum<'ctx>) -> StructField<'ctx> {
|
||||
let index = self.gep_index_counter;
|
||||
self.gep_index_counter += 1;
|
||||
|
||||
let field = StructField { gep_index: index, name, ty };
|
||||
self.fields.push(field); // Register into self.fields
|
||||
|
||||
field // Return to the caller to conveniently let them do whatever they want
|
||||
}
|
||||
|
||||
pub fn end(self) -> StructFields<'ctx> {
|
||||
StructFields { name: self.name, fields: self.fields }
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Use derppening's abstraction
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct NpArrayType<'ctx> {
|
||||
pub size_type: IntType<'ctx>,
|
||||
pub elem_type: BasicTypeEnum<'ctx>,
|
||||
}
|
||||
|
||||
pub struct NpArrayStructFields<'ctx> {
|
||||
pub whole_struct: StructFields<'ctx>,
|
||||
pub data: StructField<'ctx>,
|
||||
pub itemsize: StructField<'ctx>,
|
||||
pub ndims: StructField<'ctx>,
|
||||
pub shape: StructField<'ctx>,
|
||||
pub strides: StructField<'ctx>,
|
||||
}
|
||||
|
||||
impl<'ctx> NpArrayType<'ctx> {
|
||||
pub fn new_opaque_elem(
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
size_type: IntType<'ctx>,
|
||||
) -> NpArrayType<'ctx> {
|
||||
NpArrayType { size_type, elem_type: ctx.ctx.i8_type().as_basic_type_enum() }
|
||||
}
|
||||
|
||||
pub fn get_struct_type(&self, ctx: &'ctx Context) -> StructType<'ctx> {
|
||||
self.fields(ctx).whole_struct.get_struct_type(ctx)
|
||||
}
|
||||
|
||||
pub fn fields(&self, ctx: &'ctx Context) -> NpArrayStructFields<'ctx> {
|
||||
let mut builder = StructFieldsBuilder::start("NpArray");
|
||||
|
||||
let addrspace = AddressSpace::default();
|
||||
|
||||
let byte_type = ctx.i8_type();
|
||||
|
||||
// Make sure the struct matches PERFECTLY with that defined in `nac3core/irrt`.
|
||||
let data = builder.add_field("data", byte_type.ptr_type(addrspace).into());
|
||||
let itemsize = builder.add_field("itemsize", self.size_type.into());
|
||||
let ndims = builder.add_field("ndims", self.size_type.into());
|
||||
let shape = builder.add_field("shape", self.size_type.ptr_type(addrspace).into());
|
||||
let strides = builder.add_field("strides", self.size_type.ptr_type(addrspace).into());
|
||||
|
||||
NpArrayStructFields { whole_struct: builder.end(), data, itemsize, ndims, shape, strides }
|
||||
}
|
||||
|
||||
/// Allocate an `ndarray` on stack, with the following notes:
|
||||
///
|
||||
/// - `ndarray.ndims` will be initialized to `in_ndims`.
|
||||
/// - `ndarray.itemsize` will be initialized to the size of `self.elem_type.size_of()`.
|
||||
/// - `ndarray.shape` and `ndarray.strides` will be allocated on the stack with number of elements being `in_ndims`,
|
||||
/// all with empty/uninitialized values.
|
||||
pub fn alloca(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
in_ndims: IntValue<'ctx>,
|
||||
name: &str,
|
||||
) -> NpArrayValue<'ctx> {
|
||||
let ptr = ctx
|
||||
.builder
|
||||
.build_alloca(self.get_struct_type(ctx.ctx).as_basic_type_enum(), name)
|
||||
.unwrap();
|
||||
|
||||
// Allocate `in_dims` number of `size_type` on the stack for `shape` and `strides`
|
||||
let allocated_shape = ctx
|
||||
.builder
|
||||
.build_array_alloca(self.size_type.as_basic_type_enum(), in_ndims, "allocated_shape")
|
||||
.unwrap();
|
||||
let allocated_strides = ctx
|
||||
.builder
|
||||
.build_array_alloca(self.size_type.as_basic_type_enum(), in_ndims, "allocated_strides")
|
||||
.unwrap();
|
||||
|
||||
let value = NpArrayValue { ty: *self, ptr };
|
||||
value.store_ndims(ctx, in_ndims);
|
||||
value.store_itemsize(ctx, self.elem_type.size_of().unwrap());
|
||||
value.store_shape(ctx, allocated_shape);
|
||||
value.store_strides(ctx, allocated_strides);
|
||||
|
||||
return value;
|
||||
}
|
||||
|
||||
pub fn value_from_ptr(
|
||||
&self,
|
||||
ctx: &'ctx Context,
|
||||
in_ndarray_ptr: PointerValue<'ctx>,
|
||||
) -> NpArrayValue<'ctx> {
|
||||
if cfg!(debug_assertions) {
|
||||
// Sanity check on `in_ndarray_ptr`'s type
|
||||
|
||||
let in_ndarray_struct_type =
|
||||
in_ndarray_ptr.get_type().get_element_type().into_struct_type();
|
||||
|
||||
// unwrap to check
|
||||
self.fields(ctx).whole_struct.is_type(in_ndarray_struct_type).unwrap();
|
||||
}
|
||||
NpArrayValue { ty: *self, ptr: in_ndarray_ptr }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct NpArrayValue<'ctx> {
|
||||
pub ty: NpArrayType<'ctx>,
|
||||
pub ptr: PointerValue<'ctx>,
|
||||
}
|
||||
|
||||
impl<'ctx> NpArrayValue<'ctx> {
|
||||
pub fn load_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||
let field = self.ty.fields(ctx.ctx).data;
|
||||
field.load(ctx, self.ptr).into_pointer_value()
|
||||
}
|
||||
|
||||
pub fn store_data(&self, ctx: &CodeGenContext<'ctx, '_>, new_data_ptr: PointerValue<'ctx>) {
|
||||
let field = self.ty.fields(ctx.ctx).data;
|
||||
field.store(ctx, self.ptr, new_data_ptr);
|
||||
}
|
||||
|
||||
pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
|
||||
let field = self.ty.fields(ctx.ctx).ndims;
|
||||
field.load(ctx, self.ptr).into_int_value()
|
||||
}
|
||||
|
||||
pub fn store_ndims(&self, ctx: &CodeGenContext<'ctx, '_>, new_ndims: IntValue<'ctx>) {
|
||||
let field = self.ty.fields(ctx.ctx).ndims;
|
||||
field.store(ctx, self.ptr, new_ndims);
|
||||
}
|
||||
|
||||
pub fn load_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
|
||||
let field = self.ty.fields(ctx.ctx).itemsize;
|
||||
field.load(ctx, self.ptr).into_int_value()
|
||||
}
|
||||
|
||||
pub fn store_itemsize(&self, ctx: &CodeGenContext<'ctx, '_>, new_itemsize: IntValue<'ctx>) {
|
||||
let field = self.ty.fields(ctx.ctx).itemsize;
|
||||
field.store(ctx, self.ptr, new_itemsize);
|
||||
}
|
||||
|
||||
pub fn load_shape(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||
let field = self.ty.fields(ctx.ctx).shape;
|
||||
field.load(ctx, self.ptr).into_pointer_value()
|
||||
}
|
||||
|
||||
pub fn store_shape(&self, ctx: &CodeGenContext<'ctx, '_>, new_shape_ptr: PointerValue<'ctx>) {
|
||||
let field = self.ty.fields(ctx.ctx).shape;
|
||||
field.store(ctx, self.ptr, new_shape_ptr);
|
||||
}
|
||||
|
||||
pub fn load_strides(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||
let field = self.ty.fields(ctx.ctx).strides;
|
||||
field.load(ctx, self.ptr).into_pointer_value()
|
||||
}
|
||||
|
||||
pub fn store_strides(&self, ctx: &CodeGenContext<'ctx, '_>, value: PointerValue<'ctx>) {
|
||||
let field = self.ty.fields(ctx.ctx).strides;
|
||||
field.store(ctx, self.ptr, value);
|
||||
}
|
||||
|
||||
/// TODO: DOCUMENT ME -- NDIMS WOULD NEVER CHANGE!!!!!
|
||||
pub fn shape_slice(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
|
||||
// Get the pointer to `shape`
|
||||
let field = self.ty.fields(ctx.ctx).shape;
|
||||
let shape = field.load(ctx, self.ptr).into_pointer_value();
|
||||
|
||||
// Load `ndims`
|
||||
let ndims = self.load_ndims(ctx);
|
||||
|
||||
TypedArrayLikeAdapter {
|
||||
adapted: ArraySliceValue(shape, ndims, Some(field.name)),
|
||||
downcast_fn: Box::new(|_ctx, x| x.into_int_value()),
|
||||
upcast_fn: Box::new(|_ctx, x| x.as_basic_value_enum()),
|
||||
}
|
||||
}
|
||||
|
||||
/// TODO: DOCUMENT ME -- NDIMS WOULD NEVER CHANGE!!!!!
|
||||
pub fn strides_slice(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
|
||||
// Get the pointer to `strides`
|
||||
let field = self.ty.fields(ctx.ctx).strides;
|
||||
let strides = field.load(ctx, self.ptr).into_pointer_value();
|
||||
|
||||
// Load `ndims`
|
||||
let ndims = self.load_ndims(ctx);
|
||||
|
||||
TypedArrayLikeAdapter {
|
||||
adapted: ArraySliceValue(strides, ndims, Some(field.name)),
|
||||
downcast_fn: Box::new(|_ctx, x| x.into_int_value()),
|
||||
upcast_fn: Box::new(|_ctx, x| x.as_basic_value_enum()),
|
||||
}
|
||||
}
|
||||
}
|
||||
// impl<'ctx> StructFields<'ctx> {
|
||||
// pub fn num_fields(&self) -> u32 {
|
||||
// self.fields.len() as u32
|
||||
// }
|
||||
//
|
||||
// pub fn get_struct_type(&self, ctx: &'ctx Context) -> StructType<'ctx> {
|
||||
// let llvm_fields = self.fields.iter().map(|field| field.ty).collect_vec();
|
||||
// ctx.struct_type(llvm_fields.as_slice(), false)
|
||||
// }
|
||||
//
|
||||
// pub fn is_type(&self, scrutinee: StructType<'ctx>) -> IsInstanceResult {
|
||||
// // Check scrutinee's number of struct fields
|
||||
// if scrutinee.count_fields() != self.num_fields() {
|
||||
// return Err(format!(
|
||||
// "Expected {expected_count} field(s) in `{struct_name}` type, got {got_count}",
|
||||
// struct_name = self.name,
|
||||
// expected_count = self.num_fields(),
|
||||
// got_count = scrutinee.count_fields(),
|
||||
// ));
|
||||
// }
|
||||
//
|
||||
// // Check the scrutinee's field types
|
||||
// for field in self.fields.iter() {
|
||||
// let expected_field_ty = field.ty;
|
||||
// let got_field_ty = scrutinee.get_field_type_at_index(field.gep_index).unwrap();
|
||||
//
|
||||
// if let Err(field_err) = check_basic_types_match(expected_field_ty, got_field_ty) {
|
||||
// return Err(format!(
|
||||
// "Field GEP index {gep_index} does not match the expected type of ({struct_name}::{field_name}): {field_err}",
|
||||
// gep_index = field.gep_index,
|
||||
// struct_name = self.name,
|
||||
// field_name = field.name,
|
||||
// ));
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // Done
|
||||
// Ok(())
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// impl<'ctx> StructFieldsBuilder<'ctx> {
|
||||
// pub fn start(name: &'static str) -> Self {
|
||||
// StructFieldsBuilder { gep_index_counter: 0, name, fields: Vec::new() }
|
||||
// }
|
||||
//
|
||||
// pub fn add_field(&mut self, name: &'static str, ty: BasicTypeEnum<'ctx>) -> StructField<'ctx> {
|
||||
// let index = self.gep_index_counter;
|
||||
// self.gep_index_counter += 1;
|
||||
//
|
||||
// let field = StructField { gep_index: index, name, ty };
|
||||
// self.fields.push(field); // Register into self.fields
|
||||
//
|
||||
// field // Return to the caller to conveniently let them do whatever they want
|
||||
// }
|
||||
//
|
||||
// pub fn end(self) -> StructFields<'ctx> {
|
||||
// StructFields { name: self.name, fields: self.fields }
|
||||
// }
|
||||
// }
|
||||
//
|
||||
|
|
|
@ -2202,6 +2202,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
|||
let dst_ndims = deduce_ndims_after_slicing(ndims, ndslices.iter());
|
||||
|
||||
// Finally, perform the actual subscript logic
|
||||
// TODO: `call_ndarray_subscript_impl` under the hood deduces `dst_ndims` again. We could save it some time by passing `dst_ndims` - a TODO?
|
||||
let subndarray = call_ndarray_subscript_impl(
|
||||
generator,
|
||||
ctx,
|
||||
|
|
|
@ -0,0 +1,87 @@
|
|||
// TODO: Use derppening's abstraction
|
||||
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use inkwell::{
|
||||
context::Context,
|
||||
types::{BasicType, BasicTypeEnum, IntType},
|
||||
values::BasicValueEnum,
|
||||
AddressSpace,
|
||||
};
|
||||
|
||||
use crate::codegen::structure::{
|
||||
CustomStructType, CustomType, Field, FieldCreator, IntType2, Object, PointerType2,
|
||||
PointingArrayType,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct NpArrayType<'ctx> {
|
||||
pub size_type: IntType<'ctx>,
|
||||
pub elem_type: BasicTypeEnum<'ctx>,
|
||||
}
|
||||
|
||||
pub struct NpArrayFields<'ctx> {
|
||||
pub data: Field<'ctx, PointerType2<'ctx>>,
|
||||
pub itemsize: Field<'ctx, IntType2<'ctx>>,
|
||||
pub ndims: Field<'ctx, IntType2<'ctx>>,
|
||||
pub shape: Field<'ctx, PointingArrayType<'ctx, IntType2<'ctx>>>,
|
||||
pub strides: Field<'ctx, PointingArrayType<'ctx, IntType2<'ctx>>>,
|
||||
}
|
||||
|
||||
pub type NpArrayValue<'ctx> = Object<'ctx, NpArrayType<'ctx>>;
|
||||
|
||||
// impl<'ctx> CustomType<'ctx> for NpArrayType<'ctx> {
|
||||
// type Value = NpArrayValue<'ctx>;
|
||||
//
|
||||
// fn llvm_basic_type_enum(
|
||||
// &self,
|
||||
// ctx: &'ctx inkwell::context::Context,
|
||||
// ) -> inkwell::types::BasicTypeEnum<'ctx> {
|
||||
// self.llvm_struct_type(ctx).as_basic_type_enum()
|
||||
// }
|
||||
//
|
||||
// fn llvm_field_load(
|
||||
// &self,
|
||||
// ctx: &crate::codegen::CodeGenContext<'ctx, '_>,
|
||||
// field: crate::codegen::structure::FieldInfo,
|
||||
// struct_ptr: inkwell::values::PointerValue<'ctx>,
|
||||
// ) -> Self::Value {
|
||||
// let ok = field.llvm_load(ctx, struct_ptr);
|
||||
// todo!()
|
||||
// }
|
||||
//
|
||||
// fn llvm_field_store(
|
||||
// &self,
|
||||
// ctx: &crate::codegen::CodeGenContext<'ctx, '_>,
|
||||
// field: crate::codegen::structure::FieldInfo,
|
||||
// struct_ptr: inkwell::values::PointerValue<'ctx>,
|
||||
// value: &Self::Value,
|
||||
// ) {
|
||||
// todo!()
|
||||
// }
|
||||
// }
|
||||
|
||||
impl<'ctx> CustomStructType<'ctx> for NpArrayType<'ctx> {
|
||||
type Fields = NpArrayFields<'ctx>;
|
||||
|
||||
fn llvm_struct_name() -> &'static str {
|
||||
"NDArray"
|
||||
}
|
||||
|
||||
fn add_fields_to(&self, creator: &mut FieldCreator<'ctx>) -> Self::Fields {
|
||||
let pi8 = creator.ctx.i8_type().ptr_type(AddressSpace::default());
|
||||
NpArrayFields {
|
||||
data: creator.add_field("data", PointerType2(pi8)),
|
||||
itemsize: creator.add_field("itemsize", IntType2(self.size_type)),
|
||||
ndims: creator.add_field("ndims", IntType2(self.size_type)),
|
||||
shape: creator.add_field("shape", PointingArrayType::new(IntType2(self.size_type))),
|
||||
strides: creator.add_field("strides", PointingArrayType::new(IntType2(self.size_type))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> NpArrayType<'ctx> {
|
||||
pub fn new_opaque_elem(ctx: &'ctx Context, size_type: IntType<'ctx>) -> Self {
|
||||
NpArrayType { elem_type: ctx.i8_type().into(), size_type }
|
||||
}
|
||||
}
|
|
@ -11,13 +11,15 @@ mod test;
|
|||
use super::{
|
||||
classes::{
|
||||
check_basic_types_match, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue,
|
||||
NDArrayValue, NpArrayType, NpArrayValue, StructField, StructFields, TypedArrayLikeAdapter,
|
||||
UntypedArrayLikeAccessor,
|
||||
NDArrayValue, StructField, StructFields, TypedArrayLikeAdapter, UntypedArrayLikeAccessor,
|
||||
},
|
||||
llvm_intrinsics, CodeGenContext, CodeGenerator,
|
||||
llvm_intrinsics,
|
||||
structure::CustomStructType,
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
use crate::codegen::classes::TypedArrayLikeAccessor;
|
||||
use crate::codegen::stmt::gen_for_callback_incrementing;
|
||||
use classes::{NpArrayType, NpArrayValue};
|
||||
use crossbeam::channel::IntoIter;
|
||||
use inkwell::{
|
||||
attributes::{Attribute, AttributeLoc},
|
||||
|
@ -33,6 +35,7 @@ use inkwell::{
|
|||
};
|
||||
use itertools::Either;
|
||||
use nac3parser::ast::Expr;
|
||||
pub mod classes;
|
||||
|
||||
#[must_use]
|
||||
pub fn load_irrt(ctx: &Context) -> Module {
|
||||
|
@ -957,7 +960,7 @@ pub struct UserSlice<'ctx> {
|
|||
pub step: Option<IntValue<'ctx>>,
|
||||
}
|
||||
|
||||
pub struct IrrtUserSliceStructFields<'ctx> {
|
||||
pub struct IrrtUserSliceTypeStructFields<'ctx> {
|
||||
pub whole_struct: StructFields<'ctx>,
|
||||
|
||||
pub start_defined: StructField<'ctx>,
|
||||
|
@ -971,10 +974,10 @@ pub struct IrrtUserSliceStructFields<'ctx> {
|
|||
}
|
||||
|
||||
// TODO: EMPTY STRUCT
|
||||
struct IrrtUserSlice {}
|
||||
struct IrrtUserSliceType();
|
||||
|
||||
impl IrrtUserSlice {
|
||||
pub fn fields<'ctx>(ctx: &'ctx Context) -> IrrtUserSliceStructFields<'ctx> {
|
||||
impl IrrtUserSliceType {
|
||||
pub fn fields<'ctx>(&self, ctx: &'ctx Context) -> IrrtUserSliceTypeStructFields<'ctx> {
|
||||
let int8 = ctx.i8_type();
|
||||
|
||||
// MUST match the corresponding struct defined in IRRT
|
||||
|
@ -986,7 +989,7 @@ impl IrrtUserSlice {
|
|||
let step_defined = builder.add_field("step_defined", int8.into());
|
||||
let step = builder.add_field("step", get_sliceindex_type(ctx).into());
|
||||
|
||||
IrrtUserSliceStructFields {
|
||||
IrrtUserSliceTypeStructFields {
|
||||
start_defined,
|
||||
start,
|
||||
stop_defined,
|
||||
|
@ -998,11 +1001,12 @@ impl IrrtUserSlice {
|
|||
}
|
||||
|
||||
pub fn alloca_user_slice<'ctx>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
user_slice: &UserSlice<'ctx>,
|
||||
) -> PointerValue<'ctx> {
|
||||
// Derive the struct_type
|
||||
let fields = Self::fields(ctx.ctx);
|
||||
let fields = self.fields(ctx.ctx);
|
||||
let struct_type = fields.whole_struct.get_struct_type(ctx.ctx);
|
||||
|
||||
// ...and then allocate for a real `UserSlice` in LLVM
|
||||
|
@ -1070,23 +1074,22 @@ where
|
|||
}
|
||||
|
||||
// TODO: Empty struct
|
||||
pub struct IrrtNDSlice {}
|
||||
|
||||
pub struct IrrtNDSliceStructFields<'ctx> {
|
||||
pub struct IrrtNDSliceType();
|
||||
pub struct IrrtNDSliceTypeStructFields<'ctx> {
|
||||
pub whole_struct: StructFields<'ctx>,
|
||||
pub type_: StructField<'ctx>,
|
||||
pub slice: StructField<'ctx>,
|
||||
}
|
||||
|
||||
impl IrrtNDSlice {
|
||||
pub fn fields<'ctx>(ctx: &'ctx Context) -> IrrtNDSliceStructFields<'ctx> {
|
||||
impl IrrtNDSliceType {
|
||||
pub fn fields<'ctx>(ctx: &'ctx Context) -> IrrtNDSliceTypeStructFields<'ctx> {
|
||||
let mut builder = StructFieldsBuilder::start("NDSlice");
|
||||
|
||||
// MUST match the corresponding struct defined in IRRT
|
||||
let type_ = builder.add_field("type", get_ndslicetype_constant_type(ctx).into());
|
||||
let slice = builder.add_field("slice", get_opaque_uint8_ptr_type(ctx).into());
|
||||
|
||||
IrrtNDSliceStructFields { type_, slice, whole_struct: builder.end() }
|
||||
IrrtNDSliceTypeStructFields { type_, slice, whole_struct: builder.end() }
|
||||
}
|
||||
|
||||
pub fn alloca_ndslices<'ctx>(
|
||||
|
@ -1124,7 +1127,7 @@ impl IrrtNDSlice {
|
|||
}
|
||||
NDSlice::Slice(user_slice) => {
|
||||
// Allocate the user_slice
|
||||
let slice_ptr = IrrtUserSlice::alloca_user_slice(ctx, user_slice);
|
||||
let slice_ptr = IrrtUserSliceType().alloca_user_slice(ctx, user_slice);
|
||||
|
||||
let type_ = 1; // const NDSliceType INPUT_SLICE_TYPE_SLICE = 1;
|
||||
(type_, slice_ptr)
|
||||
|
@ -1156,6 +1159,84 @@ impl IrrtNDSlice {
|
|||
}
|
||||
}
|
||||
|
||||
struct IrrtPrinterType();
|
||||
struct IrrtPrinterTypeStructFields<'ctx> {
|
||||
whole_struct: StructFields<'ctx>,
|
||||
string_base_ptr: StructField<'ctx>,
|
||||
max_length: StructField<'ctx>,
|
||||
length: StructField<'ctx>,
|
||||
}
|
||||
|
||||
impl IrrtPrinterType {
|
||||
pub fn fields<'ctx>(&self, ctx: &'ctx Context) -> IrrtPrinterTypeStructFields<'ctx> {
|
||||
let mut builder = StructFieldsBuilder::start("Printer");
|
||||
|
||||
let string_base_ptr = builder
|
||||
.add_field("string_base_ptr", ctx.i8_type().ptr_type(AddressSpace::default()).into());
|
||||
let max_length = builder.add_field("max_length", ctx.i32_type().into());
|
||||
let length = builder.add_field("length", ctx.i32_type().into());
|
||||
|
||||
IrrtPrinterTypeStructFields {
|
||||
string_base_ptr,
|
||||
max_length,
|
||||
length,
|
||||
whole_struct: builder.end(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct IrrtPrinterValue<'ctx> {
|
||||
ty: IrrtPrinterType,
|
||||
ptr: PointerValue<'ctx>,
|
||||
}
|
||||
|
||||
impl<'ctx> IrrtPrinterValue<'ctx> {
|
||||
pub fn hl(&self) {}
|
||||
}
|
||||
|
||||
struct IrrtErrorContextType();
|
||||
struct IrrtErrorContextTypeStructFields<'ctx> {
|
||||
whole_struct: StructFields<'ctx>,
|
||||
error: StructField<'ctx>,
|
||||
}
|
||||
|
||||
impl IrrtErrorContextType {
|
||||
pub fn fields<'ctx>(&self, ctx: &'ctx Context) -> IrrtErrorContextTypeStructFields<'ctx> {
|
||||
let mut builder = StructFieldsBuilder::start("ErrorContext");
|
||||
|
||||
let error = builder.add_field(
|
||||
"error",
|
||||
IrrtPrinterType().fields(ctx).whole_struct.get_struct_type(ctx).into(),
|
||||
);
|
||||
|
||||
IrrtErrorContextTypeStructFields { error, whole_struct: builder.end() }
|
||||
}
|
||||
|
||||
pub fn alloca<'ctx>(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||
let fields = self.fields(ctx.ctx);
|
||||
let struct_type = fields.whole_struct.get_struct_type(ctx.ctx);
|
||||
|
||||
ctx.builder.build_alloca(struct_type, "error_context").unwrap()
|
||||
}
|
||||
|
||||
pub fn load_error<'ctx>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
struct_ptr: PointerValue<'ctx>,
|
||||
) -> IrrtPrinterValue<'ctx> {
|
||||
let error = self.fields(ctx.ctx).error;
|
||||
|
||||
IrrtPrinterValue {
|
||||
ty: IrrtPrinterType(),
|
||||
ptr: error.load(ctx, struct_ptr).into_pointer_value(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// struct IrrtErrorContextValue<'ctx> {
|
||||
// ty: IrrtErrorContextValue
|
||||
// }
|
||||
|
||||
fn get_size_variant<'ctx>(ty: IntType<'ctx>) -> SizeVariant {
|
||||
match ty.get_bit_width() {
|
||||
32 => SizeVariant::Bits32,
|
||||
|
@ -1198,8 +1279,7 @@ pub fn get_irrt_ndarray_ptr_type<'ctx>(
|
|||
let i8_type = ctx.i8_type();
|
||||
|
||||
let ndarray_ty = NpArrayType { size_type, elem_type: i8_type.as_basic_type_enum() };
|
||||
let struct_ty = ndarray_ty.get_struct_type(ctx);
|
||||
struct_ty.ptr_type(AddressSpace::default())
|
||||
ndarray_ty.llvm_struct_type(ctx).ptr_type(AddressSpace::default())
|
||||
}
|
||||
|
||||
pub fn get_opaque_uint8_ptr_type<'ctx>(ctx: &'ctx Context) -> PointerType<'ctx> {
|
||||
|
@ -1317,7 +1397,7 @@ pub fn call_nac3_ndarray_deduce_ndims_after_slicing<'ctx>(
|
|||
&[
|
||||
size_type.into(), // SizeT ndims
|
||||
size_type.into(), // SizeT num_slices
|
||||
IrrtNDSlice::fields(ctx.ctx)
|
||||
IrrtNDSliceType::fields(ctx.ctx)
|
||||
.whole_struct
|
||||
.get_struct_type(ctx.ctx)
|
||||
.ptr_type(AddressSpace::default())
|
||||
|
@ -1360,7 +1440,7 @@ pub fn call_nac3_ndarray_subscript<'ctx>(
|
|||
&[
|
||||
get_irrt_ndarray_ptr_type(ctx.ctx, size_type).into(), // NDArray<SizeT>* ndarray
|
||||
size_type.into(), // SizeT num_slices
|
||||
IrrtNDSlice::fields(ctx.ctx)
|
||||
IrrtNDSliceType::fields(ctx.ctx)
|
||||
.whole_struct
|
||||
.get_struct_type(ctx.ctx)
|
||||
.ptr_type(AddressSpace::default())
|
||||
|
@ -1393,19 +1473,14 @@ pub fn call_nac3_len<'ctx>(
|
|||
let size_type = ndarray.ty.size_type;
|
||||
|
||||
// Get the IRRT function
|
||||
let function = get_size_type_dependent_function(
|
||||
ctx,
|
||||
size_type,
|
||||
"__nac3_ndarray_len",
|
||||
|| {
|
||||
get_sliceindex_type(ctx.ctx).fn_type(
|
||||
&[
|
||||
get_irrt_ndarray_ptr_type(ctx.ctx, size_type).into(), // NDArray<SizeT> *ndarray
|
||||
],
|
||||
false,
|
||||
)
|
||||
},
|
||||
);
|
||||
let function = get_size_type_dependent_function(ctx, size_type, "__nac3_ndarray_len", || {
|
||||
get_sliceindex_type(ctx.ctx).fn_type(
|
||||
&[
|
||||
get_irrt_ndarray_ptr_type(ctx.ctx, size_type).into(), // NDArray<SizeT> *ndarray
|
||||
],
|
||||
false,
|
||||
)
|
||||
});
|
||||
|
||||
// Call the IRRT function
|
||||
ctx.builder
|
||||
|
|
|
@ -44,6 +44,7 @@ pub mod irrt;
|
|||
pub mod llvm_intrinsics;
|
||||
pub mod numpy;
|
||||
pub mod stmt;
|
||||
pub mod structure;
|
||||
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
|
|
|
@ -38,7 +38,7 @@ use super::{
|
|||
irrt::{
|
||||
call_nac3_ndarray_deduce_ndims_after_slicing, call_nac3_ndarray_set_strides_by_shape,
|
||||
call_nac3_ndarray_size, call_nac3_ndarray_subscript, get_irrt_ndarray_ptr_type,
|
||||
get_opaque_uint8_ptr_type, IrrtNDSlice, NDSlice,
|
||||
get_opaque_uint8_ptr_type, IrrtNDSliceType, NDSlice,
|
||||
},
|
||||
stmt::gen_return,
|
||||
};
|
||||
|
@ -2364,7 +2364,7 @@ where
|
|||
let num_slices = size_type.const_int(ndslices.len() as u64, false);
|
||||
|
||||
// Prepare the argument `slices`
|
||||
let ndslices_ptr = IrrtNDSlice::alloca_ndslices(ctx, ndslices);
|
||||
let ndslices_ptr = IrrtNDSliceType::alloca_ndslices(ctx, ndslices);
|
||||
|
||||
// Get `dst_ndims`
|
||||
let dst_ndims = call_nac3_ndarray_deduce_ndims_after_slicing(
|
||||
|
|
|
@ -0,0 +1,318 @@
|
|||
use std::marker::PhantomData;
|
||||
|
||||
use inkwell::{
|
||||
context::Context,
|
||||
types::{BasicType, BasicTypeEnum, IntType, PointerType, StructType},
|
||||
values::{BasicValue, BasicValueEnum, IntValue, PointerValue},
|
||||
AddressSpace,
|
||||
};
|
||||
|
||||
use super::CodeGenContext;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct FieldInfo {
|
||||
gep_index: u32,
|
||||
name: &'static str,
|
||||
}
|
||||
|
||||
impl FieldInfo {
|
||||
pub fn llvm_gep<'ctx>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
struct_ptr: PointerValue<'ctx>,
|
||||
) -> PointerValue<'ctx> {
|
||||
let index_type = ctx.ctx.i32_type(); // TODO: I think I'm not supposed to *just* use i32 for GEP like that
|
||||
unsafe {
|
||||
ctx.builder
|
||||
.build_in_bounds_gep(
|
||||
struct_ptr,
|
||||
&[index_type.const_zero(), index_type.const_int(self.gep_index as u64, false)],
|
||||
self.name,
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn llvm_load<'ctx>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
struct_ptr: PointerValue<'ctx>,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
// We will use `self.name` as the LLVM label for debugging purposes
|
||||
ctx.builder.build_load(self.llvm_gep(ctx, struct_ptr), self.name).unwrap()
|
||||
}
|
||||
|
||||
pub fn llvm_store<'ctx>(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
struct_ptr: PointerValue<'ctx>,
|
||||
value: BasicValueEnum<'ctx>,
|
||||
) {
|
||||
ctx.builder.build_store(self.llvm_gep(ctx, struct_ptr), value).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Object<'ctx, T> {
|
||||
pub ty: T,
|
||||
pub ptr: PointerValue<'ctx>,
|
||||
}
|
||||
|
||||
pub struct Field<'ctx, T: CustomType<'ctx>> {
|
||||
pub info: FieldInfo,
|
||||
pub ty: T,
|
||||
_phantom: PhantomData<&'ctx ()>,
|
||||
}
|
||||
|
||||
pub struct FieldCreator<'ctx> {
|
||||
pub ctx: &'ctx Context,
|
||||
struct_name: &'ctx str,
|
||||
gep_index_counter: u32,
|
||||
fields: Vec<(FieldInfo, BasicTypeEnum<'ctx>)>,
|
||||
}
|
||||
|
||||
impl<'ctx> FieldCreator<'ctx> {
|
||||
pub fn new(ctx: &'ctx Context, struct_name: &'ctx str) -> Self {
|
||||
FieldCreator { ctx, struct_name, gep_index_counter: 0, fields: Vec::new() }
|
||||
}
|
||||
|
||||
fn next_gep_index(&mut self) -> u32 {
|
||||
let index = self.gep_index_counter;
|
||||
self.gep_index_counter += 1;
|
||||
index
|
||||
}
|
||||
|
||||
fn get_struct_field_types(&self) -> Vec<BasicTypeEnum<'ctx>> {
|
||||
self.fields.iter().map(|x| x.1.clone()).collect()
|
||||
}
|
||||
|
||||
pub fn add_field<T: CustomType<'ctx>>(&mut self, name: &'static str, ty: T) -> Field<'ctx, T> {
|
||||
let gep_index = self.next_gep_index();
|
||||
|
||||
let field_type = ty.llvm_basic_type_enum(self.ctx);
|
||||
let field_info = FieldInfo { gep_index, name };
|
||||
let field = Field { info: field_info, ty, _phantom: PhantomData };
|
||||
|
||||
self.fields.push((field_info.clone(), field_type));
|
||||
|
||||
field
|
||||
}
|
||||
|
||||
fn num_fields(&self) -> u32 {
|
||||
self.fields.len() as u32 // casted to u32 because that is what inkwell returns
|
||||
}
|
||||
}
|
||||
|
||||
pub trait CustomType<'ctx>: Clone {
|
||||
type Value;
|
||||
|
||||
fn llvm_basic_type_enum(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx>;
|
||||
|
||||
fn llvm_field_load(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
field: FieldInfo,
|
||||
struct_ptr: PointerValue<'ctx>,
|
||||
) -> Self::Value;
|
||||
|
||||
fn llvm_field_store(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
field: FieldInfo,
|
||||
struct_ptr: PointerValue<'ctx>,
|
||||
value: &Self::Value,
|
||||
);
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct IntType2<'ctx>(pub IntType<'ctx>);
|
||||
|
||||
impl<'ctx> CustomType<'ctx> for IntType2<'ctx> {
|
||||
type Value = IntValue<'ctx>;
|
||||
|
||||
fn llvm_basic_type_enum(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> {
|
||||
self.0.as_basic_type_enum()
|
||||
}
|
||||
|
||||
fn llvm_field_load(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
field: FieldInfo,
|
||||
struct_ptr: PointerValue<'ctx>,
|
||||
) -> Self::Value {
|
||||
let int_value = field.llvm_load(ctx, struct_ptr).into_int_value();
|
||||
assert_eq!(int_value.get_type().get_bit_width(), self.0.get_bit_width());
|
||||
int_value
|
||||
}
|
||||
|
||||
fn llvm_field_store(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
field: FieldInfo,
|
||||
struct_ptr: PointerValue<'ctx>,
|
||||
int_value: &Self::Value,
|
||||
) {
|
||||
assert_eq!(int_value.get_type().get_bit_width(), self.0.get_bit_width());
|
||||
field.llvm_store(ctx, struct_ptr, int_value.as_basic_value_enum());
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct PointerType2<'ctx>(pub PointerType<'ctx>);
|
||||
|
||||
impl<'ctx> CustomType<'ctx> for PointerType2<'ctx> {
|
||||
type Value = PointerValue<'ctx>;
|
||||
|
||||
fn llvm_basic_type_enum(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> {
|
||||
self.0.as_basic_type_enum()
|
||||
}
|
||||
|
||||
fn llvm_field_load(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
field: FieldInfo,
|
||||
struct_ptr: PointerValue<'ctx>,
|
||||
) -> Self::Value {
|
||||
field.llvm_load(ctx, struct_ptr).into_pointer_value()
|
||||
}
|
||||
|
||||
fn llvm_field_store(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
field: FieldInfo,
|
||||
struct_ptr: PointerValue<'ctx>,
|
||||
pointer_value: &Self::Value,
|
||||
) {
|
||||
field.llvm_store(ctx, struct_ptr, pointer_value.as_basic_value_enum());
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct PointingArrayType<'ctx, ElementType: CustomType<'ctx>> {
|
||||
pub element_type: ElementType,
|
||||
_phantom: PhantomData<&'ctx ()>,
|
||||
}
|
||||
|
||||
impl<'ctx, ElementType: CustomType<'ctx>> PointingArrayType<'ctx, ElementType> {
|
||||
pub fn new(element_type: ElementType) -> Self {
|
||||
PointingArrayType { element_type, _phantom: PhantomData }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx, Element: CustomType<'ctx>> CustomType<'ctx> for PointingArrayType<'ctx, Element> {
|
||||
type Value = Object<'ctx, Self>;
|
||||
|
||||
fn llvm_basic_type_enum(&self, ctx: &'ctx Context) -> BasicTypeEnum<'ctx> {
|
||||
// Element*
|
||||
self.element_type
|
||||
.llvm_basic_type_enum(ctx)
|
||||
.ptr_type(AddressSpace::default())
|
||||
.as_basic_type_enum()
|
||||
}
|
||||
|
||||
fn llvm_field_load(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
field: FieldInfo,
|
||||
struct_ptr: PointerValue<'ctx>,
|
||||
) -> Self::Value {
|
||||
// Remember that it is just a pointer
|
||||
Object { ty: self.clone(), ptr: field.llvm_load(ctx, struct_ptr).into_pointer_value() }
|
||||
}
|
||||
|
||||
fn llvm_field_store(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
field: FieldInfo,
|
||||
struct_ptr: PointerValue<'ctx>,
|
||||
value: &Self::Value,
|
||||
) {
|
||||
// Remember that it is just a pointer
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check_basic_types_match<'ctx, A, B>(expected: A, got: B) -> Result<(), String>
|
||||
where
|
||||
A: BasicType<'ctx>,
|
||||
B: BasicType<'ctx>,
|
||||
{
|
||||
let expected = expected.as_basic_type_enum();
|
||||
let got = got.as_basic_type_enum();
|
||||
|
||||
// Put those logic into here,
|
||||
// otherwise there is always a fallback reporting on any kind of mismatch
|
||||
match (expected, got) {
|
||||
(BasicTypeEnum::IntType(expected), BasicTypeEnum::IntType(got)) => {
|
||||
if expected.get_bit_width() != got.get_bit_width() {
|
||||
return Err(format!(
|
||||
"Expected IntType ({expected}-bit(s)), got IntType ({got}-bit(s))"
|
||||
));
|
||||
}
|
||||
}
|
||||
(expected, got) => {
|
||||
if expected != got {
|
||||
return Err(format!("Expected {expected}, got {got}"));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub trait CustomStructType<'ctx> {
|
||||
type Fields;
|
||||
|
||||
fn llvm_struct_name() -> &'static str;
|
||||
|
||||
fn add_fields_to(&self, creator: &mut FieldCreator<'ctx>) -> Self::Fields;
|
||||
|
||||
fn fields(&self, ctx: &'ctx Context) -> Self::Fields {
|
||||
let mut creator = FieldCreator::new(ctx, Self::llvm_struct_name());
|
||||
let fields = self.add_fields_to(&mut creator);
|
||||
fields
|
||||
}
|
||||
|
||||
fn llvm_struct_type(&self, ctx: &'ctx Context) -> StructType<'ctx> {
|
||||
let mut creator = FieldCreator::new(ctx, Self::llvm_struct_name());
|
||||
self.add_fields_to(&mut creator);
|
||||
|
||||
ctx.struct_type(&creator.get_struct_field_types(), false)
|
||||
}
|
||||
|
||||
fn check_struct_type(
|
||||
&self,
|
||||
ctx: &'ctx Context,
|
||||
scrutinee: StructType<'ctx>,
|
||||
) -> Result<(), String> {
|
||||
let mut creator = FieldCreator::new(ctx, Self::llvm_struct_name());
|
||||
self.add_fields_to(&mut creator);
|
||||
|
||||
// Check scrutinee's number of struct fields
|
||||
let expected_field_count = creator.num_fields();
|
||||
let got_field_count = scrutinee.count_fields();
|
||||
if got_field_count != expected_field_count {
|
||||
return Err(format!(
|
||||
"Expected {expected_count} field(s) in `{struct_name}` type, got {got_count}",
|
||||
struct_name = Self::llvm_struct_name(),
|
||||
expected_count = expected_field_count,
|
||||
got_count = got_field_count,
|
||||
));
|
||||
}
|
||||
|
||||
// Check the scrutinee's field types
|
||||
for (field_info, expected_field_ty) in creator.fields {
|
||||
let got_field_ty = scrutinee.get_field_type_at_index(field_info.gep_index).unwrap();
|
||||
|
||||
if let Err(field_err) = check_basic_types_match(expected_field_ty, got_field_ty) {
|
||||
return Err(format!(
|
||||
"Field GEP index {gep_index} does not match the expected type of ({struct_name}::{field_name}): {field_err}",
|
||||
gep_index = field_info.gep_index,
|
||||
struct_name = Self::llvm_struct_name(),
|
||||
field_name = field_info.name,
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// Done
|
||||
Ok(())
|
||||
}
|
||||
}
|
|
@ -1,6 +1,7 @@
|
|||
use std::iter::once;
|
||||
|
||||
use crate::{codegen::classes::NpArrayType, util::SizeVariant};
|
||||
use crate::util::SizeVariant;
|
||||
use classes::NpArrayType;
|
||||
use helper::{debug_assert_prim_is_allowed, make_exception_fields, PrimDefDetails};
|
||||
use indexmap::IndexMap;
|
||||
use inkwell::{
|
||||
|
@ -1466,10 +1467,13 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
let ndarray_ptr = arg.into_pointer_value(); // It has to be an ndarray
|
||||
|
||||
let size_type = generator.get_size_type(ctx.ctx);
|
||||
let ndarray_ty = NpArrayType::new_opaque_elem(ctx, size_type); // We don't need to care about the element type - we only want the shape
|
||||
let ndarray_ty = NpArrayType::new_opaque_elem(ctx.ctx, size_type); // We don't need to care about the element type - we only want the shape
|
||||
let ndarray = ndarray_ty.value_from_ptr(ctx.ctx, ndarray_ptr);
|
||||
|
||||
Some(call_nac3_len(ctx, ndarray).as_basic_value_enum())
|
||||
let result = call_nac3_len(ctx, ndarray).as_basic_value_enum();
|
||||
Some(result)
|
||||
|
||||
// Some(.as_basic_value_enum())
|
||||
|
||||
// let llvm_i32 = ctx.ctx.i32_type();
|
||||
// let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
|
Loading…
Reference in New Issue