Compare commits

..

8 Commits

Author SHA1 Message Date
David Mak 4764e60011 [core] WIP - Implemented construct_* for NDArrays 2024-11-13 14:46:30 +08:00
David Mak 04cb60564b [core] codegen/types: Implement NDArray in terms of i8*
Better aligns with the future implementation of ndstrides.
2024-11-12 17:13:19 +08:00
David Mak a4fe4c5ca8 [core] irrt: Break IRRT into several impl files
Each IRRT file is now mapped to one Rust file.
2024-11-12 14:55:05 +08:00
David Mak b230ab5ee5 [core] irrt: Update some IRRT implementation
- Change CSlice to use `void*` for better pointer compatibility
- Remove __STDC_VERSION__ guard
- Only include impl *.hpp files in irrt.cpp
- Refactor typedef to using declaration
- Add missing ``// namespace`
2024-11-12 14:55:04 +08:00
David Mak 6e78513c43 [core] codegen: Add dtype to NDArrayType
We won't have this once NDArray is refactored to strided impl.
2024-11-11 20:56:27 +08:00
David Mak 4573490647 [core] codegen: Add Self::llvm_type to all type abstractions 2024-11-11 20:56:27 +08:00
lyken c4f09a49c7 core/irrt: fix exception.hpp C++ castings 2024-11-11 20:56:27 +08:00
lyken 94f11c37c3 core/toplevel/helper: add {extract,create}_ndims 2024-11-11 20:56:27 +08:00
4 changed files with 41 additions and 264 deletions

View File

@ -11,7 +11,6 @@ pub use range::*;
mod list;
mod ndarray;
mod range;
pub mod structure;
/// A LLVM type that is used to represent a corresponding type in NAC3.
pub trait ProxyType<'ctx>: Into<Self::Base> {

View File

@ -1,15 +1,12 @@
use inkwell::{
context::Context,
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
values::{IntValue, PointerValue},
values::IntValue,
AddressSpace,
};
use itertools::Itertools;
use super::{
structure::{FieldIndexCounter, StructField, StructFields},
ProxyType,
};
use super::ProxyType;
use crate::{
codegen::{
values::{ArraySliceValue, NDArrayValue, ProxyValue, TypedArrayLikeMutator},
@ -29,51 +26,6 @@ pub struct NDArrayType<'ctx> {
llvm_usize: IntType<'ctx>,
}
#[derive(PartialEq, Eq, Clone, Copy)]
pub struct NDArrayStructFields<'ctx> {
pub data: StructField<'ctx, PointerValue<'ctx>>,
pub itemsize: StructField<'ctx, IntValue<'ctx>>,
pub ndims: StructField<'ctx, IntValue<'ctx>>,
pub shape: StructField<'ctx, PointerValue<'ctx>>,
pub strides: StructField<'ctx, PointerValue<'ctx>>,
}
impl<'ctx> StructFields<'ctx> for NDArrayStructFields<'ctx> {
fn new(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self {
let mut counter = FieldIndexCounter::default();
NDArrayStructFields {
data: StructField::create(
&mut counter,
"data",
ctx.i8_type().ptr_type(AddressSpace::default()),
),
itemsize: StructField::create(&mut counter, "itemsize", llvm_usize),
ndims: StructField::create(&mut counter, "ndims", llvm_usize),
shape: StructField::create(
&mut counter,
"shape",
llvm_usize.ptr_type(AddressSpace::default()),
),
strides: StructField::create(
&mut counter,
"strides",
llvm_usize.ptr_type(AddressSpace::default()),
),
}
}
fn to_vec(&self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)> {
vec![
self.data.into(),
self.itemsize.into(),
self.ndims.into(),
self.shape.into(),
self.strides.into(),
]
}
}
impl<'ctx> NDArrayType<'ctx> {
/// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not.
pub fn is_representable(
@ -179,21 +131,27 @@ impl<'ctx> NDArrayType<'ctx> {
// TODO: Move this into e.g. StructProxyType
#[must_use]
fn fields(
fn layout(
ctx: &'ctx Context,
llvm_usize: IntType<'ctx>,
) -> NDArrayStructFields<'ctx> {
NDArrayStructFields::new(ctx, llvm_usize)
) -> Vec<(&'static str, BasicTypeEnum<'ctx>)> {
vec![
("data", ctx.i8_type().ptr_type(AddressSpace::default()).into()),
("itemsize", llvm_usize.into()),
("ndims", llvm_usize.into()),
("shape", llvm_usize.ptr_type(AddressSpace::default()).into()),
("strides", llvm_usize.ptr_type(AddressSpace::default()).into()),
]
}
// TODO: Move this into e.g. StructProxyType
#[must_use]
pub fn get_fields(
pub fn get_layout(
&self,
ctx: &'ctx Context,
llvm_usize: IntType<'ctx>,
) -> NDArrayStructFields<'ctx> {
Self::fields(ctx, llvm_usize)
) -> Vec<(&'static str, BasicTypeEnum<'ctx>)> {
Self::layout(ctx, llvm_usize)
}
/// Creates an LLVM type corresponding to the expected structure of an `NDArray`.
@ -206,10 +164,8 @@ impl<'ctx> NDArrayType<'ctx> {
// * ndims : Number of dimensions in the array
// * shape : Pointer to an array containing the shape of the NDArray
// * strides : Pointer to an array indicating the number of bytes between each element at a dimension
let field_tys = Self::fields(ctx, llvm_usize)
.into_iter()
.map(|field| field.1)
.collect_vec();
let field_tys =
Self::layout(ctx, llvm_usize).into_iter().map(|field| field.1).collect_vec();
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
}

View File

@ -1,194 +0,0 @@
use std::marker::PhantomData;
use inkwell::{
context::Context,
types::{BasicTypeEnum, IntType},
values::{BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue},
};
use crate::codegen::CodeGenContext;
/// Trait indicating that the structure is a field-wise representation of an LLVM structure.
///
/// # Usage
///
/// For example, for a simple C-slice LLVM structure:
///
/// ```ignore
/// struct CSliceFields<'ctx> {
/// ptr: StructField<'ctx, PointerValue<'ctx>>,
/// len: StructField<'ctx, IntValue<'ctx>>
/// }
/// ```
pub trait StructFields<'ctx>: Eq + Copy {
/// Creates an instance of [`StructFields`] using the given `ctx` and `size_t` types.
fn new(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self;
/// Returns a [`Vec`] that contains the fields of the structure in the order as they appear in
/// the type definition.
#[must_use]
fn to_vec(&self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)>;
/// Returns a [`Iterator`] that contains the fields of the structure in the order as they appear
/// in the type definition.
#[must_use]
fn iter(&self) -> impl Iterator<Item = (&'static str, BasicTypeEnum<'ctx>)> {
self.to_vec().into_iter()
}
/// Returns a [`Vec`] that contains the fields of the structure in the order as they appear in
/// the type definition.
#[must_use]
fn into_vec(self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)>
where
Self: Sized,
{
self.to_vec()
}
/// Returns a [`Iterator`] that contains the fields of the structure in the order as they appear
/// in the type definition.
#[must_use]
fn into_iter(self) -> impl Iterator<Item = (&'static str, BasicTypeEnum<'ctx>)>
where
Self: Sized,
{
self.into_vec().into_iter()
}
}
/// A single field of an LLVM structure.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct StructField<'ctx, Value>
where
Value: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>, Error = ()>,
{
/// The index of this field within the structure.
index: u32,
/// The name of this field.
name: &'static str,
/// The type of this field.
ty: BasicTypeEnum<'ctx>,
/// Instance of [`PhantomData`] containing [`Value`], used to implement automatic downcasts.
_value_ty: PhantomData<Value>,
}
impl<'ctx, Value> StructField<'ctx, Value>
where
Value: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>, Error = ()>,
{
/// Creates an instance of [`StructField`].
///
/// * `idx_counter` - The instance of [`FieldIndexCounter`] used to track the current field
/// index.
/// * `name` - Name of the field.
/// * `ty` - The type of this field.
pub(super) fn create(
idx_counter: &mut FieldIndexCounter,
name: &'static str,
ty: impl Into<BasicTypeEnum<'ctx>>,
) -> Self {
StructField { index: idx_counter.increment(), name, ty: ty.into(), _value_ty: PhantomData }
}
/// Creates a pointer to this field in an arbitrary structure by performing a `getelementptr i32
/// {idx...}, i32 {self.index}`.
pub fn ptr_by_array_gep(
&self,
ctx: &CodeGenContext<'ctx, '_>,
pobj: PointerValue<'ctx>,
idx: &[IntValue<'ctx>],
) -> PointerValue<'ctx> {
unsafe {
ctx.builder.build_in_bounds_gep(
pobj,
&[idx, &[ctx.ctx.i32_type().const_int(u64::from(self.index), false)]].concat(),
"",
)
}
.unwrap()
}
/// Creates a pointer to this field in an arbitrary structure by performing the equivalent of
/// `getelementptr i32 0, i32 {self.index}`.
pub fn ptr_by_gep(
&self,
ctx: &CodeGenContext<'ctx, '_>,
pobj: PointerValue<'ctx>,
obj_name: Option<&'ctx str>,
) -> PointerValue<'ctx> {
ctx.builder
.build_struct_gep(
pobj,
self.index,
&obj_name.map(|name| format!("{name}.{}.addr", self.name)).unwrap_or_default(),
)
.unwrap()
}
/// Gets the value of this field for a given `obj`.
#[must_use]
pub fn get_from_value(&self, obj: StructValue<'ctx>) -> Value {
obj.get_field_at_index(self.index).and_then(|value| Value::try_from(value).ok()).unwrap()
}
/// Sets the value of this field for a given `obj`.
pub fn set_from_value(&self, obj: StructValue<'ctx>, value: Value) {
obj.set_field_at_index(self.index, value);
}
/// Gets the value of this field for a pointer-to-structure.
pub fn get(
&self,
ctx: &CodeGenContext<'ctx, '_>,
pobj: PointerValue<'ctx>,
obj_name: Option<&'ctx str>,
) -> Value {
ctx.builder
.build_load(
self.ptr_by_gep(ctx, pobj, obj_name),
&obj_name.map(|name| format!("{name}.{}", self.name)).unwrap_or_default(),
)
.map_err(|_| ())
.and_then(|value| Value::try_from(value))
.unwrap()
}
/// Sets the value of this field for a pointer-to-structure.
pub fn set(
&self,
ctx: &CodeGenContext<'ctx, '_>,
pobj: PointerValue<'ctx>,
value: Value,
obj_name: Option<&'ctx str>,
) {
ctx.builder.build_store(self.ptr_by_gep(ctx, pobj, obj_name), value).unwrap();
}
}
impl<'ctx, Value> From<StructField<'ctx, Value>> for (&'static str, BasicTypeEnum<'ctx>)
where
Value: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>, Error = ()>,
{
fn from(value: StructField<'ctx, Value>) -> Self {
(value.name, value.ty)
}
}
/// A counter that tracks the next index of a field using a monotonically increasing counter.
#[derive(Default, Debug, PartialEq, Eq, Clone, Copy)]
pub(super) struct FieldIndexCounter(u32);
impl FieldIndexCounter {
/// Increments the number stored by this counter, returning the previous value.
///
/// Functionally equivalent to `i++` in C-based languages.
pub fn increment(&mut self) -> u32 {
let v = self.0;
self.0 += 1;
v
}
}

View File

@ -13,7 +13,7 @@ use crate::codegen::{
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
llvm_intrinsics::call_int_umin,
stmt::gen_for_callback_incrementing,
types::{structure::StructFields, NDArrayType},
types::NDArrayType,
CodeGenContext, CodeGenerator,
};
@ -59,7 +59,7 @@ impl<'ctx> NDArrayValue<'ctx> {
let field_offset = self
.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.get_layout(ctx.ctx, self.llvm_usize)
.into_iter()
.find_position(|field| field.0 == "data")
.unwrap()
@ -112,10 +112,26 @@ impl<'ctx> NDArrayValue<'ctx> {
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
self.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.ndims
.ptr_by_gep(ctx, self.as_base_value(), self.name)
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.name.map(|v| format!("{v}.ndims.addr")).unwrap_or_default();
let field_offset = self
.get_type()
.get_layout(ctx.ctx, self.llvm_usize)
.into_iter()
.find_position(|field| field.0 == "ndims")
.unwrap()
.0 as u64;
unsafe {
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(field_offset, false)],
var_name.as_str(),
)
.unwrap()
}
}
/// Stores the number of dimensions `ndims` into this instance.
@ -144,7 +160,7 @@ impl<'ctx> NDArrayValue<'ctx> {
let field_offset = self
.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.get_layout(ctx.ctx, self.llvm_usize)
.into_iter()
.find_position(|field| field.0 == "itemsize")
.unwrap()
@ -188,7 +204,7 @@ impl<'ctx> NDArrayValue<'ctx> {
let field_offset = self
.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.get_layout(ctx.ctx, self.llvm_usize)
.into_iter()
.find_position(|field| field.0 == "shape")
.unwrap()
@ -234,7 +250,7 @@ impl<'ctx> NDArrayValue<'ctx> {
let field_offset = self
.get_type()
.get_fields(ctx.ctx, self.llvm_usize)
.get_layout(ctx.ctx, self.llvm_usize)
.into_iter()
.find_position(|field| field.0 == "strides")
.unwrap()