forked from M-Labs/nac3
1
0
Fork 0

core: Refactor class abstractions

- Introduce new Type abstractions
- Rearrange some functions
This commit is contained in:
David Mak 2024-06-06 12:16:09 +08:00
parent 08129cc635
commit f0ab1b858a
4 changed files with 405 additions and 175 deletions

View File

@ -4,7 +4,7 @@ use inkwell::values::BasicValueEnum;
use itertools::Itertools;
use crate::codegen::{CodeGenContext, CodeGenerator, extern_fns, irrt, llvm_intrinsics, numpy};
use crate::codegen::classes::{NDArrayValue, UntypedArrayLikeAccessor};
use crate::codegen::classes::{NDArrayValue, ProxyValue, UntypedArrayLikeAccessor};
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
use crate::codegen::stmt::gen_for_callback_incrementing;
use crate::toplevel::helper::PRIMITIVE_DEF_IDS;
@ -93,7 +93,7 @@ pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, "int32", &[n_ty])
@ -164,7 +164,7 @@ pub fn call_int64<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, "int64", &[n_ty])
@ -251,7 +251,7 @@ pub fn call_uint32<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, "uint32", &[n_ty])
@ -332,7 +332,7 @@ pub fn call_uint64<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, "uint64", &[n_ty])
@ -397,7 +397,7 @@ pub fn call_float<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, "float", &[n_ty])
@ -443,7 +443,7 @@ pub fn call_round<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[n_ty])
@ -483,7 +483,7 @@ pub fn call_numpy_round<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[n_ty])
@ -552,7 +552,7 @@ pub fn call_bool<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[n_ty])
@ -602,7 +602,7 @@ pub fn call_floor<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[n_ty])
@ -652,7 +652,7 @@ pub fn call_ceil<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[n_ty])
@ -870,7 +870,7 @@ pub fn call_numpy_minimum<'ctx, G: CodeGenerator + ?Sized>(
|generator, ctx, (lhs, rhs)| {
call_numpy_minimum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
},
)?.as_ptr_value().into()
)?.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
@ -1088,7 +1088,7 @@ pub fn call_numpy_maximum<'ctx, G: CodeGenerator + ?Sized>(
|generator, ctx, (lhs, rhs)| {
call_numpy_maximum(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
},
)?.as_ptr_value().into()
)?.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
@ -1153,7 +1153,7 @@ pub fn call_abs<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[n_ty])
@ -1195,7 +1195,7 @@ pub fn call_numpy_isnan<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1237,7 +1237,7 @@ pub fn call_numpy_isinf<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1277,7 +1277,7 @@ pub fn call_numpy_sin<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1317,7 +1317,7 @@ pub fn call_numpy_cos<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1357,7 +1357,7 @@ pub fn call_numpy_exp<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1397,7 +1397,7 @@ pub fn call_numpy_exp2<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1437,7 +1437,7 @@ pub fn call_numpy_log<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1477,7 +1477,7 @@ pub fn call_numpy_log10<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1517,7 +1517,7 @@ pub fn call_numpy_log2<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1557,7 +1557,7 @@ pub fn call_numpy_fabs<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1597,7 +1597,7 @@ pub fn call_numpy_sqrt<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1637,7 +1637,7 @@ pub fn call_numpy_rint<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1677,7 +1677,7 @@ pub fn call_numpy_tan<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1717,7 +1717,7 @@ pub fn call_numpy_arcsin<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1757,7 +1757,7 @@ pub fn call_numpy_arccos<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1797,7 +1797,7 @@ pub fn call_numpy_arctan<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1837,7 +1837,7 @@ pub fn call_numpy_sinh<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1877,7 +1877,7 @@ pub fn call_numpy_cosh<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1917,7 +1917,7 @@ pub fn call_numpy_tanh<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1957,7 +1957,7 @@ pub fn call_numpy_arcsinh<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -1997,7 +1997,7 @@ pub fn call_numpy_arccosh<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -2037,7 +2037,7 @@ pub fn call_numpy_arctanh<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -2077,7 +2077,7 @@ pub fn call_numpy_expm1<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -2117,7 +2117,7 @@ pub fn call_numpy_cbrt<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -2157,7 +2157,7 @@ pub fn call_scipy_special_erf<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[z_ty])
@ -2197,7 +2197,7 @@ pub fn call_scipy_special_erfc<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -2237,7 +2237,7 @@ pub fn call_scipy_special_gamma<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[z_ty])
@ -2277,7 +2277,7 @@ pub fn call_scipy_special_gammaln<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -2317,7 +2317,7 @@ pub fn call_scipy_special_j0<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -2357,7 +2357,7 @@ pub fn call_scipy_special_j1<'ctx, G: CodeGenerator + ?Sized>(
},
)?;
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x_ty])
@ -2424,7 +2424,7 @@ pub fn call_numpy_arctan2<'ctx, G: CodeGenerator + ?Sized>(
|generator, ctx, (lhs, rhs)| {
call_numpy_arctan2(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
},
)?.as_ptr_value().into()
)?.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
@ -2491,7 +2491,7 @@ pub fn call_numpy_copysign<'ctx, G: CodeGenerator + ?Sized>(
|generator, ctx, (lhs, rhs)| {
call_numpy_copysign(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
},
)?.as_ptr_value().into()
)?.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
@ -2558,7 +2558,7 @@ pub fn call_numpy_fmax<'ctx, G: CodeGenerator + ?Sized>(
|generator, ctx, (lhs, rhs)| {
call_numpy_fmax(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
},
)?.as_ptr_value().into()
)?.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
@ -2625,7 +2625,7 @@ pub fn call_numpy_fmin<'ctx, G: CodeGenerator + ?Sized>(
|generator, ctx, (lhs, rhs)| {
call_numpy_fmin(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
},
)?.as_ptr_value().into()
)?.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
@ -2681,7 +2681,7 @@ pub fn call_numpy_ldexp<'ctx, G: CodeGenerator + ?Sized>(
|generator, ctx, (lhs, rhs)| {
call_numpy_ldexp(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
},
)?.as_ptr_value().into()
)?.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
@ -2748,7 +2748,7 @@ pub fn call_numpy_hypot<'ctx, G: CodeGenerator + ?Sized>(
|generator, ctx, (lhs, rhs)| {
call_numpy_hypot(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
},
)?.as_ptr_value().into()
)?.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
@ -2815,7 +2815,7 @@ pub fn call_numpy_nextafter<'ctx, G: CodeGenerator + ?Sized>(
|generator, ctx, (lhs, rhs)| {
call_numpy_nextafter(generator, ctx, (x1_scalar_ty, lhs), (x2_scalar_ty, rhs))
},
)?.as_ptr_value().into()
)?.as_base_value().into()
}
_ => unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])

View File

@ -3,6 +3,8 @@ use inkwell::{
types::{AnyTypeEnum, BasicTypeEnum, IntType, PointerType},
values::{BasicValueEnum, IntValue, PointerValue},
};
use inkwell::types::BasicType;
use inkwell::values::BasicValue;
use crate::codegen::{
CodeGenContext,
CodeGenerator,
@ -11,6 +13,40 @@ use crate::codegen::{
stmt::gen_for_callback_incrementing,
};
/// A LLVM type that is used to represent a non-primitive type in NAC3.
pub trait ProxyType<'ctx> {
/// The underlying type as represented by an LLVM type.
type Base: BasicType<'ctx>;
/// The type of values represented by this type.
type Value: ProxyValue<'ctx>;
/// Creates a [`value`][ProxyValue] with this as its type.
fn create_value(
&self,
value: <Self::Value as ProxyValue<'ctx>>::Base,
name: Option<&'ctx str>,
) -> Self::Value;
/// Returns the base type of this proxy.
fn as_base_type(&self) -> Self::Base;
}
/// A LLVM type that is used to represent a non-primitive value in NAC3.
pub trait ProxyValue<'ctx> {
/// The underlying type as represented by an LLVM value.
type Base: BasicValue<'ctx>;
/// The type of this value.
type Type: ProxyType<'ctx>;
/// Returns the [type][ProxyType] of this value.
fn get_type(&self) -> Self::Type;
/// Returns the base value of this proxy.
fn as_base_value(&self) -> Self::Base;
}
/// An LLVM value that is array-like, i.e. it contains a contiguous, sequenced collection of
/// elements.
pub trait ArrayLikeValue<'ctx> {
@ -388,26 +424,20 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for ArraySliceValue<'ctx> {
impl<'ctx> UntypedArrayLikeAccessor<'ctx> for ArraySliceValue<'ctx> {}
impl<'ctx> UntypedArrayLikeMutator<'ctx> for ArraySliceValue<'ctx> {}
#[cfg(not(debug_assertions))]
pub fn assert_is_list<'ctx>(_value: PointerValue<'ctx>, _llvm_usize: IntType<'ctx>) {}
#[cfg(debug_assertions)]
pub fn assert_is_list<'ctx>(value: PointerValue<'ctx>, llvm_usize: IntType<'ctx>) {
ListValue::is_instance(value, llvm_usize).unwrap();
/// Proxy type for a `list` type in LLVM.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct ListType<'ctx> {
ty: PointerType<'ctx>,
llvm_usize: IntType<'ctx>,
}
/// Proxy type for accessing a `list` value in LLVM.
#[derive(Copy, Clone)]
pub struct ListValue<'ctx>(PointerValue<'ctx>, Option<&'ctx str>);
impl<'ctx> ListValue<'ctx> {
/// Checks whether `value` is an instance of `list`, returning [Err] if `value` is not an
/// instance.
pub fn is_instance(
value: PointerValue<'ctx>,
impl<'ctx> ListType<'ctx> {
/// Checks whether `llvm_ty` represents a `list` type, returning [Err] if it does not.
pub fn is_type(
llvm_ty: PointerType<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
let llvm_list_ty = value.get_type().get_element_type();
let llvm_list_ty = llvm_ty.get_element_type();
let AnyTypeEnum::StructType(llvm_list_ty) = llvm_list_ty else {
return Err(format!("Expected struct type for `list` type, got {llvm_list_ty}"))
};
@ -433,28 +463,97 @@ impl<'ctx> ListValue<'ctx> {
Ok(())
}
/// Creates an [`ListType`] from a [`PointerType`].
#[must_use]
pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
debug_assert!(Self::is_type(ptr_ty, llvm_usize).is_ok());
ListType { ty: ptr_ty, llvm_usize }
}
/// Returns the type of the `size` field of this `list` type.
#[must_use]
pub fn size_type(&self) -> IntType<'ctx> {
self.as_base_type()
.get_element_type()
.into_struct_type()
.get_field_type_at_index(0)
.map(BasicTypeEnum::into_int_type)
.unwrap()
}
/// Returns the element type of this `list` type.
#[must_use]
pub fn element_type(&self) -> BasicTypeEnum<'ctx> {
self.as_base_type()
.get_element_type()
.into_struct_type()
.get_field_type_at_index(1)
.unwrap()
}
}
impl<'ctx> ProxyType<'ctx> for ListType<'ctx> {
type Base = PointerType<'ctx>;
type Value = ListValue<'ctx>;
fn create_value(
&self,
value: <Self::Value as ProxyValue<'ctx>>::Base,
name: Option<&'ctx str>,
) -> Self::Value {
debug_assert_eq!(value.get_type(), self.as_base_type());
ListValue { value, llvm_usize: self.llvm_usize, name }
}
fn as_base_type(&self) -> Self::Base {
self.ty
}
}
impl<'ctx> From<ListType<'ctx>> for PointerType<'ctx> {
fn from(value: ListType<'ctx>) -> Self {
value.as_base_type()
}
}
/// Proxy type for accessing a `list` value in LLVM.
#[derive(Copy, Clone)]
pub struct ListValue<'ctx> {
value: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>,
}
impl<'ctx> ListValue<'ctx> {
/// Checks whether `value` is an instance of `list`, returning [Err] if `value` is not an
/// instance.
pub fn is_instance(
value: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
ListType::is_type(value.get_type(), llvm_usize)
}
/// Creates an [`ListValue`] from a [`PointerValue`].
#[must_use]
pub fn from_ptr_val(ptr: PointerValue<'ctx>, llvm_usize: IntType<'ctx>, name: Option<&'ctx str>) -> Self {
assert_is_list(ptr, llvm_usize);
ListValue(ptr, name)
}
debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok());
/// Returns the underlying [`PointerValue`] pointing to the `list` instance.
#[must_use]
pub fn as_ptr_value(&self) -> PointerValue<'ctx> {
self.0
<Self as ProxyValue<'ctx>>::Type::from_type(ptr.get_type(), llvm_usize)
.create_value(ptr, name)
}
/// Returns the double-indirection pointer to the `data` array, as if by calling `getelementptr`
/// on the field.
fn pptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.1.map(|v| format!("{v}.data.addr")).unwrap_or_default();
let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default();
unsafe {
ctx.builder.build_in_bounds_gep(
self.as_ptr_value(),
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
var_name.as_str(),
).unwrap()
@ -464,11 +563,11 @@ impl<'ctx> ListValue<'ctx> {
/// Returns the pointer to the field storing the size of this `list`.
fn ptr_to_size(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.1.map(|v| format!("{v}.size.addr")).unwrap_or_default();
let var_name = self.name.map(|v| format!("{v}.size.addr")).unwrap_or_default();
unsafe {
ctx.builder.build_in_bounds_gep(
self.0,
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
var_name.as_str(),
).unwrap()
@ -519,7 +618,7 @@ impl<'ctx> ListValue<'ctx> {
let psize = self.ptr_to_size(ctx);
let var_name = name
.map(ToString::to_string)
.or_else(|| self.1.map(|v| format!("{v}.size")))
.or_else(|| self.name.map(|v| format!("{v}.size")))
.unwrap_or_default();
ctx.builder.build_load(psize, var_name.as_str())
@ -528,9 +627,22 @@ impl<'ctx> ListValue<'ctx> {
}
}
impl<'ctx> ProxyValue<'ctx> for ListValue<'ctx> {
type Base = PointerValue<'ctx>;
type Type = ListType<'ctx>;
fn get_type(&self) -> Self::Type {
ListType::from_type(self.as_base_value().get_type(), self.llvm_usize)
}
fn as_base_value(&self) -> Self::Base {
self.value
}
}
impl<'ctx> From<ListValue<'ctx>> for PointerValue<'ctx> {
fn from(value: ListValue<'ctx>) -> Self {
value.as_ptr_value()
value.as_base_value()
}
}
@ -544,7 +656,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for ListDataProxy<'ctx, '_> {
_: &CodeGenContext<'ctx, '_>,
_: &G,
) -> AnyTypeEnum<'ctx> {
self.0.0.get_type().get_element_type()
self.0.value.get_type().get_element_type()
}
fn base_ptr<G: CodeGenerator + ?Sized>(
@ -552,7 +664,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for ListDataProxy<'ctx, '_> {
ctx: &CodeGenContext<'ctx, '_>,
_: &G,
) -> PointerValue<'ctx> {
let var_name = self.0.1.map(|v| format!("{v}.data")).unwrap_or_default();
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
ctx.builder.build_load(self.0.pptr_to_data(ctx), var_name.as_str())
.map(BasicValueEnum::into_pointer_value)
@ -616,22 +728,16 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for ListDataProxy<'ctx, '_> {
impl<'ctx> UntypedArrayLikeAccessor<'ctx> for ListDataProxy<'ctx, '_> {}
impl<'ctx> UntypedArrayLikeMutator<'ctx> for ListDataProxy<'ctx, '_> {}
#[cfg(not(debug_assertions))]
pub fn assert_is_range(_value: PointerValue) {}
#[cfg(debug_assertions)]
pub fn assert_is_range(value: PointerValue) {
RangeValue::is_instance(value).unwrap();
/// Proxy type for a `range` type in LLVM.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct RangeType<'ctx> {
ty: PointerType<'ctx>,
}
/// Proxy type for accessing a `range` value in LLVM.
#[derive(Copy, Clone)]
pub struct RangeValue<'ctx>(PointerValue<'ctx>, Option<&'ctx str>);
impl<'ctx> RangeValue<'ctx> {
/// Checks whether `value` is an instance of `range`, returning [Err] if `value` is not an instance.
pub fn is_instance(value: PointerValue<'ctx>) -> Result<(), String> {
let llvm_range_ty = value.get_type().get_element_type();
impl<'ctx> RangeType<'ctx> {
/// Checks whether `llvm_ty` represents a `range` type, returning [Err] if it does not.
pub fn is_type(llvm_ty: PointerType<'ctx>) -> Result<(), String> {
let llvm_range_ty = llvm_ty.get_element_type();
let AnyTypeEnum::ArrayType(llvm_range_ty) = llvm_range_ty else {
return Err(format!("Expected array type for `range` type, got {llvm_range_ty}"))
};
@ -651,37 +757,74 @@ impl<'ctx> RangeValue<'ctx> {
Ok(())
}
/// Creates an [`RangeValue`] from a [`PointerValue`].
/// Creates an [`RangeType`] from a [`PointerType`].
#[must_use]
pub fn from_ptr_val(ptr: PointerValue<'ctx>, name: Option<&'ctx str>) -> Self {
assert_is_range(ptr);
RangeValue(ptr, name)
pub fn from_type(ptr_ty: PointerType<'ctx>) -> Self {
debug_assert!(Self::is_type(ptr_ty).is_ok());
RangeType { ty: ptr_ty }
}
/// Returns the element type of this `range` object.
/// Returns the type of all fields of this `range` type.
#[must_use]
pub fn element_type(&self) -> IntType<'ctx> {
self.as_ptr_value()
.get_type()
pub fn value_type(&self) -> IntType<'ctx> {
self.as_base_type()
.get_element_type()
.into_array_type()
.get_element_type()
.into_int_type()
}
}
/// Returns the underlying [`PointerValue`] pointing to the `range` instance.
impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> {
type Base = PointerType<'ctx>;
type Value = RangeValue<'ctx>;
fn create_value(&self, value: <Self::Value as ProxyValue<'ctx>>::Base, name: Option<&'ctx str>) -> Self::Value {
debug_assert_eq!(value.get_type(), self.as_base_type());
RangeValue { value, name }
}
fn as_base_type(&self) -> Self::Base {
self.ty
}
}
impl<'ctx> From<RangeType<'ctx>> for PointerType<'ctx> {
fn from(value: RangeType<'ctx>) -> Self {
value.as_base_type()
}
}
/// Proxy type for accessing a `range` value in LLVM.
#[derive(Copy, Clone)]
pub struct RangeValue<'ctx> {
value: PointerValue<'ctx>,
name: Option<&'ctx str>,
}
impl<'ctx> RangeValue<'ctx> {
/// Checks whether `value` is an instance of `range`, returning [Err] if `value` is not an instance.
pub fn is_instance(value: PointerValue<'ctx>) -> Result<(), String> {
RangeType::is_type(value.get_type())
}
/// Creates an [`RangeValue`] from a [`PointerValue`].
#[must_use]
pub fn as_ptr_value(&self) -> PointerValue<'ctx> {
self.0
pub fn from_ptr_val(ptr: PointerValue<'ctx>, name: Option<&'ctx str>) -> Self {
debug_assert!(Self::is_instance(ptr).is_ok());
<Self as ProxyValue<'ctx>>::Type::from_type(ptr.get_type()).create_value(ptr, name)
}
fn ptr_to_start(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.1.map(|v| format!("{v}.start.addr")).unwrap_or_default();
let var_name = self.name.map(|v| format!("{v}.start.addr")).unwrap_or_default();
unsafe {
ctx.builder.build_in_bounds_gep(
self.0,
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(0, false)],
var_name.as_str(),
).unwrap()
@ -690,11 +833,11 @@ impl<'ctx> RangeValue<'ctx> {
fn ptr_to_end(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.1.map(|v| format!("{v}.end.addr")).unwrap_or_default();
let var_name = self.name.map(|v| format!("{v}.end.addr")).unwrap_or_default();
unsafe {
ctx.builder.build_in_bounds_gep(
self.0,
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(1, false)],
var_name.as_str(),
).unwrap()
@ -703,11 +846,11 @@ impl<'ctx> RangeValue<'ctx> {
fn ptr_to_step(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.1.map(|v| format!("{v}.step.addr")).unwrap_or_default();
let var_name = self.name.map(|v| format!("{v}.step.addr")).unwrap_or_default();
unsafe {
ctx.builder.build_in_bounds_gep(
self.0,
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(2, false)],
var_name.as_str(),
).unwrap()
@ -731,7 +874,7 @@ impl<'ctx> RangeValue<'ctx> {
let pstart = self.ptr_to_start(ctx);
let var_name = name
.map(ToString::to_string)
.or_else(|| self.1.map(|v| format!("{v}.start")))
.or_else(|| self.name.map(|v| format!("{v}.start")))
.unwrap_or_default();
ctx.builder.build_load(pstart, var_name.as_str())
@ -756,7 +899,7 @@ impl<'ctx> RangeValue<'ctx> {
let pend = self.ptr_to_end(ctx);
let var_name = name
.map(ToString::to_string)
.or_else(|| self.1.map(|v| format!("{v}.end")))
.or_else(|| self.name.map(|v| format!("{v}.end")))
.unwrap_or_default();
ctx.builder.build_load(pend, var_name.as_str())
@ -781,7 +924,7 @@ impl<'ctx> RangeValue<'ctx> {
let pstep = self.ptr_to_step(ctx);
let var_name = name
.map(ToString::to_string)
.or_else(|| self.1.map(|v| format!("{v}.step")))
.or_else(|| self.name.map(|v| format!("{v}.step")))
.unwrap_or_default();
ctx.builder.build_load(pstep, var_name.as_str())
@ -790,32 +933,39 @@ impl<'ctx> RangeValue<'ctx> {
}
}
impl<'ctx> From<RangeValue<'ctx>> for PointerValue<'ctx> {
fn from(value: RangeValue<'ctx>) -> Self {
value.as_ptr_value()
impl<'ctx> ProxyValue<'ctx> for RangeValue<'ctx> {
type Base = PointerValue<'ctx>;
type Type = RangeType<'ctx>;
fn get_type(&self) -> Self::Type {
RangeType::from_type(self.value.get_type())
}
fn as_base_value(&self) -> Self::Base {
self.value
}
}
#[cfg(not(debug_assertions))]
pub fn assert_is_ndarray<'ctx>(_value: PointerValue<'ctx>, _llvm_usize: IntType<'ctx>) {}
#[cfg(debug_assertions)]
pub fn assert_is_ndarray<'ctx>(value: PointerValue<'ctx>, llvm_usize: IntType<'ctx>) {
NDArrayValue::is_instance(value, llvm_usize).unwrap();
impl<'ctx> From<RangeValue<'ctx>> for PointerValue<'ctx> {
fn from(value: RangeValue<'ctx>) -> Self {
value.as_base_value()
}
}
/// Proxy type for accessing an `NDArray` value in LLVM.
#[derive(Copy, Clone)]
pub struct NDArrayValue<'ctx>(PointerValue<'ctx>, Option<&'ctx str>);
/// Proxy type for a `ndarray` type in LLVM.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct NDArrayType<'ctx> {
ty: PointerType<'ctx>,
llvm_usize: IntType<'ctx>,
}
impl<'ctx> NDArrayValue<'ctx> {
/// Checks whether `value` is an instance of `NDArray`, returning [Err] if `value` is not an
/// instance.
pub fn is_instance(
value: PointerValue<'ctx>,
impl<'ctx> NDArrayType<'ctx> {
/// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not.
pub fn is_type(
llvm_ty: PointerType<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
let llvm_ndarray_ty = value.get_type().get_element_type();
let llvm_ndarray_ty = llvm_ty.get_element_type();
let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}"))
};
@ -855,31 +1005,96 @@ impl<'ctx> NDArrayValue<'ctx> {
Ok(())
}
/// Creates an [`NDArrayValue`] from a [`PointerValue`].
/// Creates an [`NDArrayType`] from a [`PointerType`].
#[must_use]
pub fn from_ptr_val(
ptr: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>,
) -> Self {
assert_is_ndarray(ptr, llvm_usize);
NDArrayValue(ptr, name)
pub fn from_type(ptr_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Self {
debug_assert!(Self::is_type(ptr_ty, llvm_usize).is_ok());
NDArrayType { ty: ptr_ty, llvm_usize }
}
/// Returns the underlying [`PointerValue`] pointing to the `NDArray` instance.
/// Returns the type of the `size` field of this `ndarray` type.
#[must_use]
pub fn as_ptr_value(&self) -> PointerValue<'ctx> {
self.0
pub fn size_type(&self) -> IntType<'ctx> {
self.as_base_type()
.get_element_type()
.into_struct_type()
.get_field_type_at_index(0)
.map(BasicTypeEnum::into_int_type)
.unwrap()
}
/// Returns the element type of this `ndarray` type.
#[must_use]
pub fn element_type(&self) -> BasicTypeEnum<'ctx> {
self.as_base_type()
.get_element_type()
.into_struct_type()
.get_field_type_at_index(2)
.unwrap()
}
}
impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> {
type Base = PointerType<'ctx>;
type Value = NDArrayValue<'ctx>;
fn create_value(
&self,
value: <Self::Value as ProxyValue<'ctx>>::Base,
name: Option<&'ctx str>,
) -> Self::Value {
debug_assert_eq!(value.get_type(), self.as_base_type());
NDArrayValue { value, llvm_usize: self.llvm_usize, name }
}
fn as_base_type(&self) -> Self::Base {
self.ty
}
}
impl<'ctx> From<NDArrayType<'ctx>> for PointerType<'ctx> {
fn from(value: NDArrayType<'ctx>) -> Self {
value.as_base_type()
}
}
/// Proxy type for accessing an `NDArray` value in LLVM.
#[derive(Copy, Clone)]
pub struct NDArrayValue<'ctx> {
value: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>,
}
impl<'ctx> NDArrayValue<'ctx> {
/// Checks whether `value` is an instance of `NDArray`, returning [Err] if `value` is not an
/// instance.
pub fn is_instance(
value: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
NDArrayType::is_type(value.get_type(), llvm_usize)
}
/// Creates an [`NDArrayValue`] from a [`PointerValue`].
#[must_use]
pub fn from_ptr_val(ptr: PointerValue<'ctx>, llvm_usize: IntType<'ctx>, name: Option<&'ctx str>) -> Self {
debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok());
<Self as ProxyValue<'ctx>>::Type::from_type(ptr.get_type(), llvm_usize)
.create_value(ptr, name)
}
/// Returns the pointer to the field storing the number of dimensions of this `NDArray`.
fn ptr_to_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.1.map(|v| format!("{v}.ndims.addr")).unwrap_or_default();
let var_name = self.name.map(|v| format!("{v}.ndims.addr")).unwrap_or_default();
unsafe {
ctx.builder.build_in_bounds_gep(
self.0,
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
var_name.as_str(),
).unwrap()
@ -911,11 +1126,11 @@ impl<'ctx> NDArrayValue<'ctx> {
/// on the field.
fn ptr_to_dims(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.1.map(|v| format!("{v}.dims.addr")).unwrap_or_default();
let var_name = self.name.map(|v| format!("{v}.dims.addr")).unwrap_or_default();
unsafe {
ctx.builder.build_in_bounds_gep(
self.as_ptr_value(),
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
var_name.as_str(),
).unwrap()
@ -947,11 +1162,11 @@ impl<'ctx> NDArrayValue<'ctx> {
/// on the field.
fn ptr_to_data(&self, ctx: &CodeGenContext<'ctx, '_>) -> PointerValue<'ctx> {
let llvm_i32 = ctx.ctx.i32_type();
let var_name = self.1.map(|v| format!("{v}.data.addr")).unwrap_or_default();
let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default();
unsafe {
ctx.builder.build_in_bounds_gep(
self.as_ptr_value(),
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(2, true)],
var_name.as_str(),
).unwrap()
@ -981,9 +1196,22 @@ impl<'ctx> NDArrayValue<'ctx> {
}
}
impl<'ctx> ProxyValue<'ctx> for NDArrayValue<'ctx> {
type Base = PointerValue<'ctx>;
type Type = NDArrayType<'ctx>;
fn get_type(&self) -> Self::Type {
NDArrayType::from_type(self.as_base_value().get_type(), self.llvm_usize)
}
fn as_base_value(&self) -> Self::Base {
self.value
}
}
impl<'ctx> From<NDArrayValue<'ctx>> for PointerValue<'ctx> {
fn from(value: NDArrayValue<'ctx>) -> Self {
value.as_ptr_value()
value.as_base_value()
}
}
@ -1005,7 +1233,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDimsProxy<'ctx, '_> {
ctx: &CodeGenContext<'ctx, '_>,
_: &G,
) -> PointerValue<'ctx> {
let var_name = self.0.1.map(|v| format!("{v}.data")).unwrap_or_default();
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
ctx.builder.build_load(self.0.ptr_to_dims(ctx), var_name.as_str())
.map(BasicValueEnum::into_pointer_value)
@ -1110,7 +1338,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> {
ctx: &CodeGenContext<'ctx, '_>,
_: &G,
) -> PointerValue<'ctx> {
let var_name = self.0.1.map(|v| format!("{v}.data")).unwrap_or_default();
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
ctx.builder.build_load(self.0.ptr_to_data(ctx), var_name.as_str())
.map(BasicValueEnum::into_pointer_value)

View File

@ -8,6 +8,7 @@ use crate::{
ArraySliceValue,
ListValue,
NDArrayValue,
ProxyValue,
RangeValue,
TypedArrayLikeAccessor,
UntypedArrayLikeAccessor,
@ -1090,7 +1091,7 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
emit_cont_bb(ctx, generator, list);
Ok(Some(list.as_ptr_value().into()))
Ok(Some(list.as_base_value().into()))
}
/// Generates LLVM IR for a binary operator expression using the [`Type`] and
@ -1173,8 +1174,8 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
ctx,
ndarray_dtype1,
if is_aug_assign { Some(left_val) } else { None },
(left_val.as_ptr_value().into(), false),
(right_val.as_ptr_value().into(), false),
(left_val.as_base_value().into(), false),
(right_val.as_base_value().into(), false),
|generator, ctx, (lhs, rhs)| {
gen_binop_expr_with_values(
generator,
@ -1189,7 +1190,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
)?
};
Ok(Some(res.as_ptr_value().into()))
Ok(Some(res.as_base_value().into()))
} else {
let (ndarray_dtype, _) = unpack_ndarray_var_tys(
&mut ctx.unifier,
@ -1220,7 +1221,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
},
)?;
Ok(Some(res.as_ptr_value().into()))
Ok(Some(res.as_base_value().into()))
}
} else {
let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap());
@ -1410,7 +1411,7 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
},
)?;
res.as_ptr_value().into()
res.as_base_value().into()
} else {
unimplemented!()
}))
@ -1478,7 +1479,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
ctx,
ctx.primitives.bool,
None,
(left_val.as_ptr_value().into(), false),
(left_val.as_base_value().into(), false),
(rhs, false),
|generator, ctx, (lhs, rhs)| {
let val = gen_cmpop_expr_with_values(
@ -1493,7 +1494,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
},
)?;
Ok(Some(res.as_ptr_value().into()))
Ok(Some(res.as_base_value().into()))
} else {
let (ndarray_dtype, _) = unpack_ndarray_var_tys(
&mut ctx.unifier,
@ -1519,7 +1520,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
},
)?;
Ok(Some(res.as_ptr_value().into()))
Ok(Some(res.as_base_value().into()))
}
}
}
@ -1819,7 +1820,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
ty,
v,
&slices,
)?.as_ptr_value().into()
)?.as_base_value().into()
}
ExprKind::Slice { .. } => {
@ -1833,7 +1834,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
ty,
v,
&[slice],
)?.as_ptr_value().into()
)?.as_base_value().into()
}
_ => {
@ -1935,7 +1936,7 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
llvm_i1.const_zero(),
);
ndarray.as_ptr_value().into()
ndarray.as_base_value().into()
}
}))
}
@ -2025,7 +2026,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
.ptr_offset(ctx, generator, &usize.const_int(i as u64, false), Some("elem_ptr"));
ctx.builder.build_store(elem_ptr, *v).unwrap();
}
arr_str_ptr.as_ptr_value().into()
arr_str_ptr.as_base_value().into()
}
ExprKind::Tuple { elts, .. } => {
let elements_val = elts
@ -2406,7 +2407,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
v,
(start, end, step),
);
res_array_ret.as_ptr_value().into()
res_array_ret.as_base_value().into()
} else {
let len = v.load_size(ctx, Some("len"));
let raw_index = if let Some(v) = generator.gen_expr(ctx, slice)? {

View File

@ -7,6 +7,7 @@ use crate::{
ArrayLikeValue,
ListValue,
NDArrayValue,
ProxyValue,
TypedArrayLikeAccessor,
TypedArrayLikeAdapter,
TypedArrayLikeMutator,
@ -1172,7 +1173,7 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
);
}
let lhs = if res.is_some_and(|res| res.as_ptr_value() == lhs.as_ptr_value()) {
let lhs = if res.is_some_and(|res| res.as_base_value() == lhs.as_base_value()) {
ndarray_copy_impl(generator, ctx, elem_ty, lhs)?
} else {
lhs