Compare commits
8 Commits
2b3e26c421
...
4764e60011
Author | SHA1 | Date |
---|---|---|
David Mak | 4764e60011 | |
David Mak | 04cb60564b | |
David Mak | a4fe4c5ca8 | |
David Mak | b230ab5ee5 | |
David Mak | 6e78513c43 | |
David Mak | 4573490647 | |
lyken | c4f09a49c7 | |
lyken | 94f11c37c3 |
|
@ -11,7 +11,6 @@ pub use range::*;
|
|||
mod list;
|
||||
mod ndarray;
|
||||
mod range;
|
||||
pub mod structure;
|
||||
|
||||
/// A LLVM type that is used to represent a corresponding type in NAC3.
|
||||
pub trait ProxyType<'ctx>: Into<Self::Base> {
|
||||
|
|
|
@ -1,15 +1,12 @@
|
|||
use inkwell::{
|
||||
context::Context,
|
||||
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
|
||||
values::{IntValue, PointerValue},
|
||||
values::IntValue,
|
||||
AddressSpace,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
|
||||
use super::{
|
||||
structure::{FieldIndexCounter, StructField, StructFields},
|
||||
ProxyType,
|
||||
};
|
||||
use super::ProxyType;
|
||||
use crate::{
|
||||
codegen::{
|
||||
values::{ArraySliceValue, NDArrayValue, ProxyValue, TypedArrayLikeMutator},
|
||||
|
@ -29,51 +26,6 @@ pub struct NDArrayType<'ctx> {
|
|||
llvm_usize: IntType<'ctx>,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq, Clone, Copy)]
|
||||
pub struct NDArrayStructFields<'ctx> {
|
||||
pub data: StructField<'ctx, PointerValue<'ctx>>,
|
||||
pub itemsize: StructField<'ctx, IntValue<'ctx>>,
|
||||
pub ndims: StructField<'ctx, IntValue<'ctx>>,
|
||||
pub shape: StructField<'ctx, PointerValue<'ctx>>,
|
||||
pub strides: StructField<'ctx, PointerValue<'ctx>>,
|
||||
}
|
||||
|
||||
impl<'ctx> StructFields<'ctx> for NDArrayStructFields<'ctx> {
|
||||
fn new(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self {
|
||||
let mut counter = FieldIndexCounter::default();
|
||||
|
||||
NDArrayStructFields {
|
||||
data: StructField::create(
|
||||
&mut counter,
|
||||
"data",
|
||||
ctx.i8_type().ptr_type(AddressSpace::default()),
|
||||
),
|
||||
itemsize: StructField::create(&mut counter, "itemsize", llvm_usize),
|
||||
ndims: StructField::create(&mut counter, "ndims", llvm_usize),
|
||||
shape: StructField::create(
|
||||
&mut counter,
|
||||
"shape",
|
||||
llvm_usize.ptr_type(AddressSpace::default()),
|
||||
),
|
||||
strides: StructField::create(
|
||||
&mut counter,
|
||||
"strides",
|
||||
llvm_usize.ptr_type(AddressSpace::default()),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
fn to_vec(&self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)> {
|
||||
vec![
|
||||
self.data.into(),
|
||||
self.itemsize.into(),
|
||||
self.ndims.into(),
|
||||
self.shape.into(),
|
||||
self.strides.into(),
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> NDArrayType<'ctx> {
|
||||
/// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not.
|
||||
pub fn is_representable(
|
||||
|
@ -179,21 +131,27 @@ impl<'ctx> NDArrayType<'ctx> {
|
|||
|
||||
// TODO: Move this into e.g. StructProxyType
|
||||
#[must_use]
|
||||
fn fields(
|
||||
fn layout(
|
||||
ctx: &'ctx Context,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
) -> NDArrayStructFields<'ctx> {
|
||||
NDArrayStructFields::new(ctx, llvm_usize)
|
||||
) -> Vec<(&'static str, BasicTypeEnum<'ctx>)> {
|
||||
vec![
|
||||
("data", ctx.i8_type().ptr_type(AddressSpace::default()).into()),
|
||||
("itemsize", llvm_usize.into()),
|
||||
("ndims", llvm_usize.into()),
|
||||
("shape", llvm_usize.ptr_type(AddressSpace::default()).into()),
|
||||
("strides", llvm_usize.ptr_type(AddressSpace::default()).into()),
|
||||
]
|
||||
}
|
||||
|
||||
// TODO: Move this into e.g. StructProxyType
|
||||
#[must_use]
|
||||
pub fn get_fields(
|
||||
pub fn get_layout(
|
||||
&self,
|
||||
ctx: &'ctx Context,
|
||||
llvm_usize: IntType<'ctx>,
|
||||
) -> NDArrayStructFields<'ctx> {
|
||||
Self::fields(ctx, llvm_usize)
|
||||
) -> Vec<(&'static str, BasicTypeEnum<'ctx>)> {
|
||||
Self::layout(ctx, llvm_usize)
|
||||
}
|
||||
|
||||
/// Creates an LLVM type corresponding to the expected structure of an `NDArray`.
|
||||
|
@ -206,10 +164,8 @@ impl<'ctx> NDArrayType<'ctx> {
|
|||
// * ndims : Number of dimensions in the array
|
||||
// * shape : Pointer to an array containing the shape of the NDArray
|
||||
// * strides : Pointer to an array indicating the number of bytes between each element at a dimension
|
||||
let field_tys = Self::fields(ctx, llvm_usize)
|
||||
.into_iter()
|
||||
.map(|field| field.1)
|
||||
.collect_vec();
|
||||
let field_tys =
|
||||
Self::layout(ctx, llvm_usize).into_iter().map(|field| field.1).collect_vec();
|
||||
|
||||
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
|
||||
}
|
||||
|
|
|
@ -1,194 +0,0 @@
|
|||
use std::marker::PhantomData;
|
||||
|
||||
use inkwell::{
|
||||
context::Context,
|
||||
types::{BasicTypeEnum, IntType},
|
||||
values::{BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue},
|
||||
};
|
||||
|
||||
use crate::codegen::CodeGenContext;
|
||||
|
||||
/// Trait indicating that the structure is a field-wise representation of an LLVM structure.
|
||||
///
|
||||
/// # Usage
|
||||
///
|
||||
/// For example, for a simple C-slice LLVM structure:
|
||||
///
|
||||
/// ```ignore
|
||||
/// struct CSliceFields<'ctx> {
|
||||
/// ptr: StructField<'ctx, PointerValue<'ctx>>,
|
||||
/// len: StructField<'ctx, IntValue<'ctx>>
|
||||
/// }
|
||||
/// ```
|
||||
pub trait StructFields<'ctx>: Eq + Copy {
|
||||
/// Creates an instance of [`StructFields`] using the given `ctx` and `size_t` types.
|
||||
fn new(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self;
|
||||
|
||||
/// Returns a [`Vec`] that contains the fields of the structure in the order as they appear in
|
||||
/// the type definition.
|
||||
#[must_use]
|
||||
fn to_vec(&self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)>;
|
||||
|
||||
/// Returns a [`Iterator`] that contains the fields of the structure in the order as they appear
|
||||
/// in the type definition.
|
||||
#[must_use]
|
||||
fn iter(&self) -> impl Iterator<Item = (&'static str, BasicTypeEnum<'ctx>)> {
|
||||
self.to_vec().into_iter()
|
||||
}
|
||||
|
||||
/// Returns a [`Vec`] that contains the fields of the structure in the order as they appear in
|
||||
/// the type definition.
|
||||
#[must_use]
|
||||
fn into_vec(self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
self.to_vec()
|
||||
}
|
||||
|
||||
/// Returns a [`Iterator`] that contains the fields of the structure in the order as they appear
|
||||
/// in the type definition.
|
||||
#[must_use]
|
||||
fn into_iter(self) -> impl Iterator<Item = (&'static str, BasicTypeEnum<'ctx>)>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
self.into_vec().into_iter()
|
||||
}
|
||||
}
|
||||
|
||||
/// A single field of an LLVM structure.
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
pub struct StructField<'ctx, Value>
|
||||
where
|
||||
Value: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>, Error = ()>,
|
||||
{
|
||||
/// The index of this field within the structure.
|
||||
index: u32,
|
||||
|
||||
/// The name of this field.
|
||||
name: &'static str,
|
||||
|
||||
/// The type of this field.
|
||||
ty: BasicTypeEnum<'ctx>,
|
||||
|
||||
/// Instance of [`PhantomData`] containing [`Value`], used to implement automatic downcasts.
|
||||
_value_ty: PhantomData<Value>,
|
||||
}
|
||||
|
||||
impl<'ctx, Value> StructField<'ctx, Value>
|
||||
where
|
||||
Value: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>, Error = ()>,
|
||||
{
|
||||
/// Creates an instance of [`StructField`].
|
||||
///
|
||||
/// * `idx_counter` - The instance of [`FieldIndexCounter`] used to track the current field
|
||||
/// index.
|
||||
/// * `name` - Name of the field.
|
||||
/// * `ty` - The type of this field.
|
||||
pub(super) fn create(
|
||||
idx_counter: &mut FieldIndexCounter,
|
||||
name: &'static str,
|
||||
ty: impl Into<BasicTypeEnum<'ctx>>,
|
||||
) -> Self {
|
||||
StructField { index: idx_counter.increment(), name, ty: ty.into(), _value_ty: PhantomData }
|
||||
}
|
||||
|
||||
/// Creates a pointer to this field in an arbitrary structure by performing a `getelementptr i32
|
||||
/// {idx...}, i32 {self.index}`.
|
||||
pub fn ptr_by_array_gep(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
pobj: PointerValue<'ctx>,
|
||||
idx: &[IntValue<'ctx>],
|
||||
) -> PointerValue<'ctx> {
|
||||
unsafe {
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
pobj,
|
||||
&[idx, &[ctx.ctx.i32_type().const_int(u64::from(self.index), false)]].concat(),
|
||||
"",
|
||||
)
|
||||
}
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Creates a pointer to this field in an arbitrary structure by performing the equivalent of
|
||||
/// `getelementptr i32 0, i32 {self.index}`.
|
||||
pub fn ptr_by_gep(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
pobj: PointerValue<'ctx>,
|
||||
obj_name: Option<&'ctx str>,
|
||||
) -> PointerValue<'ctx> {
|
||||
ctx.builder
|
||||
.build_struct_gep(
|
||||
pobj,
|
||||
self.index,
|
||||
&obj_name.map(|name| format!("{name}.{}.addr", self.name)).unwrap_or_default(),
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Gets the value of this field for a given `obj`.
|
||||
#[must_use]
|
||||
pub fn get_from_value(&self, obj: StructValue<'ctx>) -> Value {
|
||||
obj.get_field_at_index(self.index).and_then(|value| Value::try_from(value).ok()).unwrap()
|
||||
}
|
||||
|
||||
/// Sets the value of this field for a given `obj`.
|
||||
pub fn set_from_value(&self, obj: StructValue<'ctx>, value: Value) {
|
||||
obj.set_field_at_index(self.index, value);
|
||||
}
|
||||
|
||||
/// Gets the value of this field for a pointer-to-structure.
|
||||
pub fn get(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
pobj: PointerValue<'ctx>,
|
||||
obj_name: Option<&'ctx str>,
|
||||
) -> Value {
|
||||
ctx.builder
|
||||
.build_load(
|
||||
self.ptr_by_gep(ctx, pobj, obj_name),
|
||||
&obj_name.map(|name| format!("{name}.{}", self.name)).unwrap_or_default(),
|
||||
)
|
||||
.map_err(|_| ())
|
||||
.and_then(|value| Value::try_from(value))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Sets the value of this field for a pointer-to-structure.
|
||||
pub fn set(
|
||||
&self,
|
||||
ctx: &CodeGenContext<'ctx, '_>,
|
||||
pobj: PointerValue<'ctx>,
|
||||
value: Value,
|
||||
obj_name: Option<&'ctx str>,
|
||||
) {
|
||||
ctx.builder.build_store(self.ptr_by_gep(ctx, pobj, obj_name), value).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx, Value> From<StructField<'ctx, Value>> for (&'static str, BasicTypeEnum<'ctx>)
|
||||
where
|
||||
Value: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>, Error = ()>,
|
||||
{
|
||||
fn from(value: StructField<'ctx, Value>) -> Self {
|
||||
(value.name, value.ty)
|
||||
}
|
||||
}
|
||||
|
||||
/// A counter that tracks the next index of a field using a monotonically increasing counter.
|
||||
#[derive(Default, Debug, PartialEq, Eq, Clone, Copy)]
|
||||
pub(super) struct FieldIndexCounter(u32);
|
||||
|
||||
impl FieldIndexCounter {
|
||||
/// Increments the number stored by this counter, returning the previous value.
|
||||
///
|
||||
/// Functionally equivalent to `i++` in C-based languages.
|
||||
pub fn increment(&mut self) -> u32 {
|
||||
let v = self.0;
|
||||
self.0 += 1;
|
||||
v
|
||||
}
|
||||
}
|
|
@ -13,7 +13,7 @@ use crate::codegen::{
|
|||
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
|
||||
llvm_intrinsics::call_int_umin,
|
||||
stmt::gen_for_callback_incrementing,
|
||||
types::{structure::StructFields, NDArrayType},
|
||||
types::NDArrayType,
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
|
||||
|
@ -59,7 +59,7 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||
|
||||
let field_offset = self
|
||||
.get_type()
|
||||
.get_fields(ctx.ctx, self.llvm_usize)
|
||||
.get_layout(ctx.ctx, self.llvm_usize)
|
||||
.into_iter()
|
||||
.find_position(|field| field.0 == "data")
|
||||
.unwrap()
|
||||
|
@ -112,10 +112,26 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||
|
||||
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
|
||||
fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||
self.get_type()
|
||||
.get_fields(ctx.ctx, self.llvm_usize)
|
||||
.ndims
|
||||
.ptr_by_gep(ctx, self.as_base_value(), self.name)
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let var_name = self.name.map(|v| format!("{v}.ndims.addr")).unwrap_or_default();
|
||||
|
||||
let field_offset = self
|
||||
.get_type()
|
||||
.get_layout(ctx.ctx, self.llvm_usize)
|
||||
.into_iter()
|
||||
.find_position(|field| field.0 == "ndims")
|
||||
.unwrap()
|
||||
.0 as u64;
|
||||
|
||||
unsafe {
|
||||
ctx.builder
|
||||
.build_in_bounds_gep(
|
||||
self.as_base_value(),
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_int(field_offset, false)],
|
||||
var_name.as_str(),
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
/// Stores the number of dimensions `ndims` into this instance.
|
||||
|
@ -144,7 +160,7 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||
|
||||
let field_offset = self
|
||||
.get_type()
|
||||
.get_fields(ctx.ctx, self.llvm_usize)
|
||||
.get_layout(ctx.ctx, self.llvm_usize)
|
||||
.into_iter()
|
||||
.find_position(|field| field.0 == "itemsize")
|
||||
.unwrap()
|
||||
|
@ -188,7 +204,7 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||
|
||||
let field_offset = self
|
||||
.get_type()
|
||||
.get_fields(ctx.ctx, self.llvm_usize)
|
||||
.get_layout(ctx.ctx, self.llvm_usize)
|
||||
.into_iter()
|
||||
.find_position(|field| field.0 == "shape")
|
||||
.unwrap()
|
||||
|
@ -234,7 +250,7 @@ impl<'ctx> NDArrayValue<'ctx> {
|
|||
|
||||
let field_offset = self
|
||||
.get_type()
|
||||
.get_fields(ctx.ctx, self.llvm_usize)
|
||||
.get_layout(ctx.ctx, self.llvm_usize)
|
||||
.into_iter()
|
||||
.find_position(|field| field.0 == "strides")
|
||||
.unwrap()
|
||||
|
|
Loading…
Reference in New Issue