Compare commits
8 Commits
2b3e26c421
...
4764e60011
Author | SHA1 | Date |
---|---|---|
David Mak | 4764e60011 | |
David Mak | 04cb60564b | |
David Mak | a4fe4c5ca8 | |
David Mak | b230ab5ee5 | |
David Mak | 6e78513c43 | |
David Mak | 4573490647 | |
lyken | c4f09a49c7 | |
lyken | 94f11c37c3 |
|
@ -11,7 +11,6 @@ pub use range::*;
|
||||||
mod list;
|
mod list;
|
||||||
mod ndarray;
|
mod ndarray;
|
||||||
mod range;
|
mod range;
|
||||||
pub mod structure;
|
|
||||||
|
|
||||||
/// A LLVM type that is used to represent a corresponding type in NAC3.
|
/// A LLVM type that is used to represent a corresponding type in NAC3.
|
||||||
pub trait ProxyType<'ctx>: Into<Self::Base> {
|
pub trait ProxyType<'ctx>: Into<Self::Base> {
|
||||||
|
|
|
@ -1,15 +1,12 @@
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
context::Context,
|
context::Context,
|
||||||
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
|
types::{AnyTypeEnum, BasicType, BasicTypeEnum, IntType, PointerType},
|
||||||
values::{IntValue, PointerValue},
|
values::IntValue,
|
||||||
AddressSpace,
|
AddressSpace,
|
||||||
};
|
};
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
|
|
||||||
use super::{
|
use super::ProxyType;
|
||||||
structure::{FieldIndexCounter, StructField, StructFields},
|
|
||||||
ProxyType,
|
|
||||||
};
|
|
||||||
use crate::{
|
use crate::{
|
||||||
codegen::{
|
codegen::{
|
||||||
values::{ArraySliceValue, NDArrayValue, ProxyValue, TypedArrayLikeMutator},
|
values::{ArraySliceValue, NDArrayValue, ProxyValue, TypedArrayLikeMutator},
|
||||||
|
@ -29,51 +26,6 @@ pub struct NDArrayType<'ctx> {
|
||||||
llvm_usize: IntType<'ctx>,
|
llvm_usize: IntType<'ctx>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Clone, Copy)]
|
|
||||||
pub struct NDArrayStructFields<'ctx> {
|
|
||||||
pub data: StructField<'ctx, PointerValue<'ctx>>,
|
|
||||||
pub itemsize: StructField<'ctx, IntValue<'ctx>>,
|
|
||||||
pub ndims: StructField<'ctx, IntValue<'ctx>>,
|
|
||||||
pub shape: StructField<'ctx, PointerValue<'ctx>>,
|
|
||||||
pub strides: StructField<'ctx, PointerValue<'ctx>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> StructFields<'ctx> for NDArrayStructFields<'ctx> {
|
|
||||||
fn new(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self {
|
|
||||||
let mut counter = FieldIndexCounter::default();
|
|
||||||
|
|
||||||
NDArrayStructFields {
|
|
||||||
data: StructField::create(
|
|
||||||
&mut counter,
|
|
||||||
"data",
|
|
||||||
ctx.i8_type().ptr_type(AddressSpace::default()),
|
|
||||||
),
|
|
||||||
itemsize: StructField::create(&mut counter, "itemsize", llvm_usize),
|
|
||||||
ndims: StructField::create(&mut counter, "ndims", llvm_usize),
|
|
||||||
shape: StructField::create(
|
|
||||||
&mut counter,
|
|
||||||
"shape",
|
|
||||||
llvm_usize.ptr_type(AddressSpace::default()),
|
|
||||||
),
|
|
||||||
strides: StructField::create(
|
|
||||||
&mut counter,
|
|
||||||
"strides",
|
|
||||||
llvm_usize.ptr_type(AddressSpace::default()),
|
|
||||||
),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn to_vec(&self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)> {
|
|
||||||
vec![
|
|
||||||
self.data.into(),
|
|
||||||
self.itemsize.into(),
|
|
||||||
self.ndims.into(),
|
|
||||||
self.shape.into(),
|
|
||||||
self.strides.into(),
|
|
||||||
]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> NDArrayType<'ctx> {
|
impl<'ctx> NDArrayType<'ctx> {
|
||||||
/// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not.
|
/// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not.
|
||||||
pub fn is_representable(
|
pub fn is_representable(
|
||||||
|
@ -179,21 +131,27 @@ impl<'ctx> NDArrayType<'ctx> {
|
||||||
|
|
||||||
// TODO: Move this into e.g. StructProxyType
|
// TODO: Move this into e.g. StructProxyType
|
||||||
#[must_use]
|
#[must_use]
|
||||||
fn fields(
|
fn layout(
|
||||||
ctx: &'ctx Context,
|
ctx: &'ctx Context,
|
||||||
llvm_usize: IntType<'ctx>,
|
llvm_usize: IntType<'ctx>,
|
||||||
) -> NDArrayStructFields<'ctx> {
|
) -> Vec<(&'static str, BasicTypeEnum<'ctx>)> {
|
||||||
NDArrayStructFields::new(ctx, llvm_usize)
|
vec![
|
||||||
|
("data", ctx.i8_type().ptr_type(AddressSpace::default()).into()),
|
||||||
|
("itemsize", llvm_usize.into()),
|
||||||
|
("ndims", llvm_usize.into()),
|
||||||
|
("shape", llvm_usize.ptr_type(AddressSpace::default()).into()),
|
||||||
|
("strides", llvm_usize.ptr_type(AddressSpace::default()).into()),
|
||||||
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Move this into e.g. StructProxyType
|
// TODO: Move this into e.g. StructProxyType
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn get_fields(
|
pub fn get_layout(
|
||||||
&self,
|
&self,
|
||||||
ctx: &'ctx Context,
|
ctx: &'ctx Context,
|
||||||
llvm_usize: IntType<'ctx>,
|
llvm_usize: IntType<'ctx>,
|
||||||
) -> NDArrayStructFields<'ctx> {
|
) -> Vec<(&'static str, BasicTypeEnum<'ctx>)> {
|
||||||
Self::fields(ctx, llvm_usize)
|
Self::layout(ctx, llvm_usize)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates an LLVM type corresponding to the expected structure of an `NDArray`.
|
/// Creates an LLVM type corresponding to the expected structure of an `NDArray`.
|
||||||
|
@ -206,10 +164,8 @@ impl<'ctx> NDArrayType<'ctx> {
|
||||||
// * ndims : Number of dimensions in the array
|
// * ndims : Number of dimensions in the array
|
||||||
// * shape : Pointer to an array containing the shape of the NDArray
|
// * shape : Pointer to an array containing the shape of the NDArray
|
||||||
// * strides : Pointer to an array indicating the number of bytes between each element at a dimension
|
// * strides : Pointer to an array indicating the number of bytes between each element at a dimension
|
||||||
let field_tys = Self::fields(ctx, llvm_usize)
|
let field_tys =
|
||||||
.into_iter()
|
Self::layout(ctx, llvm_usize).into_iter().map(|field| field.1).collect_vec();
|
||||||
.map(|field| field.1)
|
|
||||||
.collect_vec();
|
|
||||||
|
|
||||||
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
|
ctx.struct_type(&field_tys, false).ptr_type(AddressSpace::default())
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,194 +0,0 @@
|
||||||
use std::marker::PhantomData;
|
|
||||||
|
|
||||||
use inkwell::{
|
|
||||||
context::Context,
|
|
||||||
types::{BasicTypeEnum, IntType},
|
|
||||||
values::{BasicValue, BasicValueEnum, IntValue, PointerValue, StructValue},
|
|
||||||
};
|
|
||||||
|
|
||||||
use crate::codegen::CodeGenContext;
|
|
||||||
|
|
||||||
/// Trait indicating that the structure is a field-wise representation of an LLVM structure.
|
|
||||||
///
|
|
||||||
/// # Usage
|
|
||||||
///
|
|
||||||
/// For example, for a simple C-slice LLVM structure:
|
|
||||||
///
|
|
||||||
/// ```ignore
|
|
||||||
/// struct CSliceFields<'ctx> {
|
|
||||||
/// ptr: StructField<'ctx, PointerValue<'ctx>>,
|
|
||||||
/// len: StructField<'ctx, IntValue<'ctx>>
|
|
||||||
/// }
|
|
||||||
/// ```
|
|
||||||
pub trait StructFields<'ctx>: Eq + Copy {
|
|
||||||
/// Creates an instance of [`StructFields`] using the given `ctx` and `size_t` types.
|
|
||||||
fn new(ctx: &'ctx Context, llvm_usize: IntType<'ctx>) -> Self;
|
|
||||||
|
|
||||||
/// Returns a [`Vec`] that contains the fields of the structure in the order as they appear in
|
|
||||||
/// the type definition.
|
|
||||||
#[must_use]
|
|
||||||
fn to_vec(&self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)>;
|
|
||||||
|
|
||||||
/// Returns a [`Iterator`] that contains the fields of the structure in the order as they appear
|
|
||||||
/// in the type definition.
|
|
||||||
#[must_use]
|
|
||||||
fn iter(&self) -> impl Iterator<Item = (&'static str, BasicTypeEnum<'ctx>)> {
|
|
||||||
self.to_vec().into_iter()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns a [`Vec`] that contains the fields of the structure in the order as they appear in
|
|
||||||
/// the type definition.
|
|
||||||
#[must_use]
|
|
||||||
fn into_vec(self) -> Vec<(&'static str, BasicTypeEnum<'ctx>)>
|
|
||||||
where
|
|
||||||
Self: Sized,
|
|
||||||
{
|
|
||||||
self.to_vec()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns a [`Iterator`] that contains the fields of the structure in the order as they appear
|
|
||||||
/// in the type definition.
|
|
||||||
#[must_use]
|
|
||||||
fn into_iter(self) -> impl Iterator<Item = (&'static str, BasicTypeEnum<'ctx>)>
|
|
||||||
where
|
|
||||||
Self: Sized,
|
|
||||||
{
|
|
||||||
self.into_vec().into_iter()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A single field of an LLVM structure.
|
|
||||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
|
||||||
pub struct StructField<'ctx, Value>
|
|
||||||
where
|
|
||||||
Value: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>, Error = ()>,
|
|
||||||
{
|
|
||||||
/// The index of this field within the structure.
|
|
||||||
index: u32,
|
|
||||||
|
|
||||||
/// The name of this field.
|
|
||||||
name: &'static str,
|
|
||||||
|
|
||||||
/// The type of this field.
|
|
||||||
ty: BasicTypeEnum<'ctx>,
|
|
||||||
|
|
||||||
/// Instance of [`PhantomData`] containing [`Value`], used to implement automatic downcasts.
|
|
||||||
_value_ty: PhantomData<Value>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx, Value> StructField<'ctx, Value>
|
|
||||||
where
|
|
||||||
Value: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>, Error = ()>,
|
|
||||||
{
|
|
||||||
/// Creates an instance of [`StructField`].
|
|
||||||
///
|
|
||||||
/// * `idx_counter` - The instance of [`FieldIndexCounter`] used to track the current field
|
|
||||||
/// index.
|
|
||||||
/// * `name` - Name of the field.
|
|
||||||
/// * `ty` - The type of this field.
|
|
||||||
pub(super) fn create(
|
|
||||||
idx_counter: &mut FieldIndexCounter,
|
|
||||||
name: &'static str,
|
|
||||||
ty: impl Into<BasicTypeEnum<'ctx>>,
|
|
||||||
) -> Self {
|
|
||||||
StructField { index: idx_counter.increment(), name, ty: ty.into(), _value_ty: PhantomData }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Creates a pointer to this field in an arbitrary structure by performing a `getelementptr i32
|
|
||||||
/// {idx...}, i32 {self.index}`.
|
|
||||||
pub fn ptr_by_array_gep(
|
|
||||||
&self,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
pobj: PointerValue<'ctx>,
|
|
||||||
idx: &[IntValue<'ctx>],
|
|
||||||
) -> PointerValue<'ctx> {
|
|
||||||
unsafe {
|
|
||||||
ctx.builder.build_in_bounds_gep(
|
|
||||||
pobj,
|
|
||||||
&[idx, &[ctx.ctx.i32_type().const_int(u64::from(self.index), false)]].concat(),
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
}
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Creates a pointer to this field in an arbitrary structure by performing the equivalent of
|
|
||||||
/// `getelementptr i32 0, i32 {self.index}`.
|
|
||||||
pub fn ptr_by_gep(
|
|
||||||
&self,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
pobj: PointerValue<'ctx>,
|
|
||||||
obj_name: Option<&'ctx str>,
|
|
||||||
) -> PointerValue<'ctx> {
|
|
||||||
ctx.builder
|
|
||||||
.build_struct_gep(
|
|
||||||
pobj,
|
|
||||||
self.index,
|
|
||||||
&obj_name.map(|name| format!("{name}.{}.addr", self.name)).unwrap_or_default(),
|
|
||||||
)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Gets the value of this field for a given `obj`.
|
|
||||||
#[must_use]
|
|
||||||
pub fn get_from_value(&self, obj: StructValue<'ctx>) -> Value {
|
|
||||||
obj.get_field_at_index(self.index).and_then(|value| Value::try_from(value).ok()).unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Sets the value of this field for a given `obj`.
|
|
||||||
pub fn set_from_value(&self, obj: StructValue<'ctx>, value: Value) {
|
|
||||||
obj.set_field_at_index(self.index, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Gets the value of this field for a pointer-to-structure.
|
|
||||||
pub fn get(
|
|
||||||
&self,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
pobj: PointerValue<'ctx>,
|
|
||||||
obj_name: Option<&'ctx str>,
|
|
||||||
) -> Value {
|
|
||||||
ctx.builder
|
|
||||||
.build_load(
|
|
||||||
self.ptr_by_gep(ctx, pobj, obj_name),
|
|
||||||
&obj_name.map(|name| format!("{name}.{}", self.name)).unwrap_or_default(),
|
|
||||||
)
|
|
||||||
.map_err(|_| ())
|
|
||||||
.and_then(|value| Value::try_from(value))
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Sets the value of this field for a pointer-to-structure.
|
|
||||||
pub fn set(
|
|
||||||
&self,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
pobj: PointerValue<'ctx>,
|
|
||||||
value: Value,
|
|
||||||
obj_name: Option<&'ctx str>,
|
|
||||||
) {
|
|
||||||
ctx.builder.build_store(self.ptr_by_gep(ctx, pobj, obj_name), value).unwrap();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx, Value> From<StructField<'ctx, Value>> for (&'static str, BasicTypeEnum<'ctx>)
|
|
||||||
where
|
|
||||||
Value: BasicValue<'ctx> + TryFrom<BasicValueEnum<'ctx>, Error = ()>,
|
|
||||||
{
|
|
||||||
fn from(value: StructField<'ctx, Value>) -> Self {
|
|
||||||
(value.name, value.ty)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A counter that tracks the next index of a field using a monotonically increasing counter.
|
|
||||||
#[derive(Default, Debug, PartialEq, Eq, Clone, Copy)]
|
|
||||||
pub(super) struct FieldIndexCounter(u32);
|
|
||||||
|
|
||||||
impl FieldIndexCounter {
|
|
||||||
/// Increments the number stored by this counter, returning the previous value.
|
|
||||||
///
|
|
||||||
/// Functionally equivalent to `i++` in C-based languages.
|
|
||||||
pub fn increment(&mut self) -> u32 {
|
|
||||||
let v = self.0;
|
|
||||||
self.0 += 1;
|
|
||||||
v
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -13,7 +13,7 @@ use crate::codegen::{
|
||||||
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
|
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
|
||||||
llvm_intrinsics::call_int_umin,
|
llvm_intrinsics::call_int_umin,
|
||||||
stmt::gen_for_callback_incrementing,
|
stmt::gen_for_callback_incrementing,
|
||||||
types::{structure::StructFields, NDArrayType},
|
types::NDArrayType,
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -59,7 +59,7 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||||
|
|
||||||
let field_offset = self
|
let field_offset = self
|
||||||
.get_type()
|
.get_type()
|
||||||
.get_fields(ctx.ctx, self.llvm_usize)
|
.get_layout(ctx.ctx, self.llvm_usize)
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.find_position(|field| field.0 == "data")
|
.find_position(|field| field.0 == "data")
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
@ -112,10 +112,26 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||||
|
|
||||||
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
|
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
|
||||||
fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
|
||||||
self.get_type()
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
.get_fields(ctx.ctx, self.llvm_usize)
|
let var_name = self.name.map(|v| format!("{v}.ndims.addr")).unwrap_or_default();
|
||||||
.ndims
|
|
||||||
.ptr_by_gep(ctx, self.as_base_value(), self.name)
|
let field_offset = self
|
||||||
|
.get_type()
|
||||||
|
.get_layout(ctx.ctx, self.llvm_usize)
|
||||||
|
.into_iter()
|
||||||
|
.find_position(|field| field.0 == "ndims")
|
||||||
|
.unwrap()
|
||||||
|
.0 as u64;
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
ctx.builder
|
||||||
|
.build_in_bounds_gep(
|
||||||
|
self.as_base_value(),
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_int(field_offset, false)],
|
||||||
|
var_name.as_str(),
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Stores the number of dimensions `ndims` into this instance.
|
/// Stores the number of dimensions `ndims` into this instance.
|
||||||
|
@ -144,7 +160,7 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||||
|
|
||||||
let field_offset = self
|
let field_offset = self
|
||||||
.get_type()
|
.get_type()
|
||||||
.get_fields(ctx.ctx, self.llvm_usize)
|
.get_layout(ctx.ctx, self.llvm_usize)
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.find_position(|field| field.0 == "itemsize")
|
.find_position(|field| field.0 == "itemsize")
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
@ -188,7 +204,7 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||||
|
|
||||||
let field_offset = self
|
let field_offset = self
|
||||||
.get_type()
|
.get_type()
|
||||||
.get_fields(ctx.ctx, self.llvm_usize)
|
.get_layout(ctx.ctx, self.llvm_usize)
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.find_position(|field| field.0 == "shape")
|
.find_position(|field| field.0 == "shape")
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
@ -234,7 +250,7 @@ impl<'ctx> NDArrayValue<'ctx> {
|
||||||
|
|
||||||
let field_offset = self
|
let field_offset = self
|
||||||
.get_type()
|
.get_type()
|
||||||
.get_fields(ctx.ctx, self.llvm_usize)
|
.get_layout(ctx.ctx, self.llvm_usize)
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.find_position(|field| field.0 == "strides")
|
.find_position(|field| field.0 == "strides")
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
|
Loading…
Reference in New Issue