forked from M-Labs/nac3
1
0
Fork 0

WIP: core/ndstrides: remove ScalarObject

This commit is contained in:
lyken 2024-08-15 13:34:48 +08:00
parent f8b934096d
commit 4b765cfb27
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
8 changed files with 627 additions and 526 deletions

View File

@ -1,18 +1,71 @@
use inkwell::values::BasicValueEnum;
use inkwell::{
values::{BasicValue, BasicValueEnum, FloatValue, IntValue},
FloatPredicate, IntPredicate,
};
use itertools::Itertools;
use list::ListObject;
use ndarray::NDArrayObject;
use ndarray::{NDArrayObject, NDArrayOut};
use range::RangeObject;
use tuple::TupleObject;
use crate::typecheck::typedef::{Type, TypeEnum};
use crate::{
toplevel::helper::PrimDef,
typecheck::typedef::{Type, TypeEnum},
};
use super::{model::*, CodeGenContext, CodeGenerator};
use super::{llvm_intrinsics, model::*, CodeGenContext, CodeGenerator};
pub mod list;
pub mod ndarray;
pub mod range;
pub mod tuple;
/// Convenience function to crash the program when types of arguments are not supported.
/// Used to be debugged with a stacktrace.
fn unsupported_type<I>(ctx: &CodeGenContext<'_, '_>, tys: I) -> !
where
I: IntoIterator<Item = Type>,
{
unreachable!(
"unsupported types found '{}'",
tys.into_iter().map(|ty| format!("'{}'", ctx.unifier.stringify(ty))).join(", "),
)
}
#[derive(Debug, Clone, Copy)]
pub enum FloorOrCeil {
Floor,
Ceil,
}
#[derive(Debug, Clone, Copy)]
pub enum MinOrMax {
Min,
Max,
}
fn signed_ints(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
vec![ctx.primitives.int32, ctx.primitives.int64]
}
fn unsigned_ints(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
vec![ctx.primitives.uint32, ctx.primitives.uint64]
}
fn ints(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
vec![ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.uint32, ctx.primitives.uint64]
}
fn int_like(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
vec![
ctx.primitives.bool,
ctx.primitives.int32,
ctx.primitives.int64,
ctx.primitives.uint32,
ctx.primitives.uint64,
]
}
#[derive(Debug, Clone, Copy)]
pub struct AnyObject<'ctx> {
pub ty: Type,
@ -20,13 +73,260 @@ pub struct AnyObject<'ctx> {
}
impl<'ctx> AnyObject<'ctx> {
// Get the `len()` of this object.
pub fn len<G: CodeGenerator + ?Sized>(
/// Returns true if this object's type is a [`TypeEnum::TObj`] and has the object ID as `prim`.
pub fn is_obj(&self, ctx: &mut CodeGenContext<'ctx, '_>, prim: PrimDef) -> bool {
match &*ctx.unifier.get_ty(self.ty) {
TypeEnum::TObj { obj_id, .. } => *obj_id == prim.id(),
_ => false,
}
}
/// Returns true if this object's type is a [`TypeEnum::TTuple`]
pub fn is_tuple(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
matches!(&*ctx.unifier.get_ty(self.ty), TypeEnum::TTuple { .. })
}
pub fn is_int32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
ctx.unifier.unioned(self.ty, ctx.primitives.int32)
}
pub fn is_uint32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
ctx.unifier.unioned(self.ty, ctx.primitives.uint32)
}
pub fn is_int64(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
ctx.unifier.unioned(self.ty, ctx.primitives.int64)
}
pub fn is_uint64(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
ctx.unifier.unioned(self.ty, ctx.primitives.uint64)
}
pub fn is_bool(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
ctx.unifier.unioned(self.ty, ctx.primitives.bool)
}
pub fn is_int_like(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
ctx.unifier.unioned_any(self.ty, int_like(ctx))
}
pub fn is_signed_int(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
ctx.unifier.unioned_any(self.ty, signed_ints(ctx))
}
pub fn is_unsigned_int(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
ctx.unifier.unioned_any(self.ty, unsigned_ints(ctx))
}
/// Convenience function. If object has type `int32`, `int64`, `uint32`, `uint64`, or `bool`,
/// get its underlying LLVM value.
///
/// Panic if the type is wrong.
pub fn into_int(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
if self.is_int_like(ctx) {
self.value.into_int_value()
} else {
panic!("not an int32 type")
}
}
pub fn is_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
self.is_obj(ctx, PrimDef::Float)
}
/// Convenience function. If object has type `float`, get its underlying LLVM value.
///
/// Panic if the type is wrong.
pub fn into_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Float<'ctx, Float64> {
if self.is_float(ctx) {
// self.value must be a FloatValue
FloatModel(Float64).believe_value(self.value.into_float_value())
} else {
panic!("not a float type")
}
}
pub fn is_ndarray(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
self.is_obj(ctx, PrimDef::NDArray)
}
pub fn into_ndarray<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Int<'ctx, Int32> {
match &*ctx.unifier.get_ty_immutable(self.ty) {
) -> NDArrayObject<'ctx> {
NDArrayObject::from_object(generator, ctx, *self)
}
/// Helper function to compare two scalars.
///
/// Only int-to-int and float-to-float comparisons are allowed.
///
/// Panic otherwise.
pub fn compare_int_or_float_by_predicate<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
lhs: AnyObject<'ctx>,
rhs: AnyObject<'ctx>,
int_predicate: IntPredicate,
float_predicate: FloatPredicate,
name: &str,
) -> Int<'ctx, Bool> {
if !ctx.unifier.unioned(lhs.ty, rhs.ty) {
panic!("lhs and rhs type are not the same.")
}
let bool_model = IntModel(Bool);
let common_ty = lhs.ty;
let result = if lhs.is_float(ctx) {
let lhs = lhs.into_float(ctx);
let rhs = rhs.into_float(ctx);
ctx.builder.build_float_compare(float_predicate, lhs.value, rhs.value, name).unwrap()
} else if ctx.unifier.unioned_any(common_ty, int_like(ctx)) {
let lhs = lhs.into_int(ctx);
let rhs = rhs.into_int(ctx);
ctx.builder.build_int_compare(int_predicate, lhs, rhs, name).unwrap()
} else {
unsupported_type(ctx, [lhs.ty, rhs.ty]);
};
bool_model.check_value(generator, ctx.ctx, result).unwrap()
}
/// Helper function for `int32()`, `int64()`, `uint32()`, and `uint64()`.
///
/// TODO: Document me
fn cast_to_int_conversion<'a, G, HandleFloatFn>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
ret_int_ty: Type,
handle_float: HandleFloatFn,
) -> AnyObject<'ctx>
where
G: CodeGenerator + ?Sized,
HandleFloatFn:
FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, FloatValue<'ctx>) -> IntValue<'ctx>,
{
let ret_int_ty_llvm = ctx.get_llvm_type(generator, ret_int_ty).into_int_type();
let result = if self.is_float(ctx) {
// Handle float to int
let n = self.into_float(ctx);
handle_float(generator, ctx, n.value)
} else if self.is_int_like(ctx) {
// Handle int to a new int type
let n = self.into_int(ctx);
if n.get_type().get_bit_width() <= ret_int_ty_llvm.get_bit_width() {
ctx.builder.build_int_z_extend(n, ret_int_ty_llvm, "zext").unwrap()
} else {
ctx.builder.build_int_truncate(n, ret_int_ty_llvm, "trunc").unwrap()
}
} else {
unsupported_type(ctx, [self.ty]);
};
assert_eq!(ret_int_ty_llvm.get_bit_width(), result.get_type().get_bit_width()); // Sanity check
AnyObject { value: result.into(), ty: ret_int_ty }
}
/// Call `int32()` on this object.
#[must_use]
pub fn call_int32<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> AnyObject<'ctx> {
self.cast_to_int_conversion(
generator,
ctx,
ctx.primitives.int32,
|_generator, ctx, input| {
let n =
ctx.builder.build_float_to_signed_int(input, ctx.ctx.i64_type(), "").unwrap();
ctx.builder.build_int_truncate(n, ctx.ctx.i32_type(), "conv").unwrap()
},
)
}
/// Call `int64()` on this object.
#[must_use]
pub fn call_int64<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> AnyObject<'ctx> {
self.cast_to_int_conversion(
generator,
ctx,
ctx.primitives.int64,
|_generator, ctx, input| {
ctx.builder.build_float_to_signed_int(input, ctx.ctx.i64_type(), "").unwrap()
},
)
}
/// Call `uint32()` on this object.
#[must_use]
pub fn call_uint32<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> AnyObject<'ctx> {
self.cast_to_int_conversion(generator, ctx, ctx.primitives.uint32, |_generator, ctx, n| {
let n_gez = ctx
.builder
.build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "")
.unwrap();
let to_int32 =
ctx.builder.build_float_to_signed_int(n, ctx.ctx.i32_type(), "").unwrap();
let to_uint64 =
ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap();
ctx.builder
.build_select(
n_gez,
ctx.builder.build_int_truncate(to_uint64, ctx.ctx.i32_type(), "").unwrap(),
to_int32,
"conv",
)
.unwrap()
.into_int_value()
})
}
/// Call `uint64()` on this object.
#[must_use]
pub fn call_uint64<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> AnyObject<'ctx> {
self.cast_to_int_conversion(generator, ctx, ctx.primitives.uint64, |_generator, ctx, n| {
let val_gez = ctx
.builder
.build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "")
.unwrap();
let to_int64 =
ctx.builder.build_float_to_signed_int(n, ctx.ctx.i64_type(), "").unwrap();
let to_uint64 =
ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap();
ctx.builder.build_select(val_gez, to_uint64, to_int64, "conv").unwrap().into_int_value()
})
}
// Get the `len()` of this object.
pub fn call_len<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> AnyObject<'ctx> {
// TODO: Switch to returning SizeT
let result = match &*ctx.unifier.get_ty_immutable(self.ty) {
TypeEnum::TTuple { .. } => {
let tuple = TupleObject::from_object(ctx, *self);
tuple.len(generator, ctx).truncate(generator, ctx, Int32, "tuple_len_32")
@ -50,6 +350,259 @@ impl<'ctx> AnyObject<'ctx> {
ndarray.len(generator, ctx).truncate(generator, ctx, Int32, "ndarray_len_i32")
}
_ => unreachable!(),
};
AnyObject { ty: ctx.primitives.int32, value: result.value.as_basic_value_enum() }
}
/// Like [`AnyObject::call_bool`] but this returns an `Int<'ctx, Bool>` instead of an object.
pub fn bool(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Int<'ctx, Bool> {
let bool_model = IntModel(Bool);
if self.is_int_like(ctx) {
let n = self.into_int(ctx);
let n = ctx
.builder
.build_int_compare(inkwell::IntPredicate::NE, n, n.get_type().const_zero(), "bool")
.unwrap();
bool_model.believe_value(n)
} else if self.is_float(ctx) {
let n = self.value.into_float_value();
let n = ctx
.builder
.build_float_compare(FloatPredicate::UNE, n, n.get_type().const_zero(), "bool")
.unwrap();
bool_model.believe_value(n)
} else {
unsupported_type(ctx, [self.ty])
}
}
/// Call `bool()` on this object.
pub fn call_bool<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> AnyObject<'ctx> {
let n = self.bool(ctx);
// NAC3 booleans are i8
let llvm_i8 = ctx.ctx.i8_type();
let n = ctx.builder.build_int_z_extend(n.value, llvm_i8, "bool").unwrap();
AnyObject { ty: ctx.primitives.bool, value: n.as_basic_value_enum() }
}
/// Call `float()` on this object.
pub fn call_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> AnyObject<'ctx> {
let f64_model = FloatModel(Float64);
let llvm_f64 = ctx.ctx.f64_type();
let result = if self.is_float(ctx) {
self.into_float(ctx)
} else if self.is_signed_int(ctx) || self.is_bool(ctx) {
let n = self.into_int(ctx);
let n = ctx.builder.build_signed_int_to_float(n, llvm_f64, "sitofp").unwrap();
f64_model.believe_value(n)
} else if self.is_unsigned_int(ctx) {
let n = self.into_int(ctx);
let n = ctx.builder.build_unsigned_int_to_float(n, llvm_f64, "uitofp").unwrap();
f64_model.believe_value(n)
} else {
unsupported_type(ctx, [self.ty]);
};
AnyObject { ty: ctx.primitives.float, value: result.value.as_basic_value_enum() }
}
// Call `abs()` on this object.
#[must_use]
pub fn call_abs<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Self {
if self.is_float(ctx) {
let n = self.value.into_float_value();
let n = llvm_intrinsics::call_float_fabs(ctx, n, Some("abs"));
AnyObject { value: n.into(), ty: ctx.primitives.float }
} else if self.is_unsigned_int(ctx) || self.is_signed_int(ctx) {
let is_poisoned = ctx.ctx.bool_type().const_zero(); // is_poisoned = false
let n = self.value.into_int_value();
let n = llvm_intrinsics::call_int_abs(ctx, n, is_poisoned, Some("abs"));
AnyObject { value: n.into(), ty: self.ty }
} else if self.is_ndarray(ctx) {
let ndarray = self.into_ndarray(generator, ctx);
ndarray
.map(
generator,
ctx,
NDArrayOut::NewNDArray { dtype: ndarray.dtype },
|generator, ctx, scalar| Ok(scalar.call_abs(generator, ctx)),
)
.unwrap()
.to_any_object(ctx)
} else {
unsupported_type(ctx, [self.ty])
}
}
// Call `round()` on this object.
//
// It is possible to specify which kind of int type to return.
pub fn call_round<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ret_int_ty: Type,
) -> AnyObject<'ctx> {
let ret_int_ty_llvm = ctx.get_llvm_type(generator, ret_int_ty).into_int_type();
let result = if ctx.unifier.unioned(self.ty, ctx.primitives.float) {
let n = self.value.into_float_value();
let n = llvm_intrinsics::call_float_round(ctx, n, None);
ctx.builder.build_float_to_signed_int(n, ret_int_ty_llvm, "round").unwrap()
} else {
unsupported_type(ctx, [self.ty])
};
AnyObject { ty: ret_int_ty, value: result.as_basic_value_enum() }
}
/// Call `np_round()` on this object.
///
/// NOTE: `np.round()` has different behaviors than `round()` when in comes to "tie" cases and return type.
#[must_use]
pub fn call_np_round<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> AnyObject<'ctx> {
if self.is_float(ctx) {
let n = self.into_float(ctx);
let n = llvm_intrinsics::call_float_rint(ctx, n.value, None);
AnyObject { ty: ctx.primitives.float, value: n.as_basic_value_enum() }
} else if self.is_ndarray(ctx) {
let ndarray = self.into_ndarray(generator, ctx);
ndarray
.map(
generator,
ctx,
NDArrayOut::NewNDArray { dtype: ndarray.dtype },
|generator, ctx, scalar| Ok(scalar.call_np_round(generator, ctx)),
)
.unwrap()
.to_any_object(ctx)
} else {
unsupported_type(ctx, [self.ty])
}
}
/// Call `min()` or `max()` on two objects.
pub fn call_min_or_max(
ctx: &mut CodeGenContext<'ctx, '_>,
kind: MinOrMax,
a: AnyObject<'ctx>,
b: AnyObject<'ctx>,
) -> AnyObject<'ctx> {
if !ctx.unifier.unioned(a.ty, b.ty) {
unsupported_type(ctx, [a.ty, b.ty])
}
let common_ty = a.ty;
if a.is_float(ctx) {
let function = match kind {
MinOrMax::Min => llvm_intrinsics::call_float_minnum,
MinOrMax::Max => llvm_intrinsics::call_float_maxnum,
};
let a = a.into_float(ctx).value;
let b = b.into_float(ctx).value;
let result = function(ctx, a, b, None);
AnyObject { value: result.as_basic_value_enum(), ty: ctx.primitives.float }
} else if a.is_unsigned_int(ctx) || a.is_bool(ctx) {
// Treating bool has an unsigned int since that is convenient
let function = match kind {
MinOrMax::Min => llvm_intrinsics::call_int_umin,
MinOrMax::Max => llvm_intrinsics::call_int_umax,
};
let a = a.into_int(ctx);
let b = b.into_int(ctx);
let result = function(ctx, a, b, None);
AnyObject { value: result.as_basic_value_enum(), ty: common_ty }
} else if a.is_signed_int(ctx) {
let function = match kind {
MinOrMax::Min => llvm_intrinsics::call_int_smin,
MinOrMax::Max => llvm_intrinsics::call_int_smax,
};
let a = a.into_int(ctx);
let b = b.into_int(ctx);
let result = function(ctx, a, b, None);
AnyObject { value: result.as_basic_value_enum(), ty: common_ty }
} else {
unsupported_type(ctx, [common_ty])
}
}
/// Call `floor()` or `ceil()` on this object.
///
/// It is possible to specify which kind of int type to return.
#[must_use]
pub fn call_floor_or_ceil<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
kind: FloorOrCeil,
ret_int_ty: Type,
) -> Self {
let ret_int_ty_llvm = ctx.get_llvm_type(generator, ret_int_ty).into_int_type();
if self.is_float(ctx) {
let function = match kind {
FloorOrCeil::Floor => llvm_intrinsics::call_float_floor,
FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil,
};
let n = self.into_float(ctx).value;
let n = function(ctx, n, None);
let n = ctx.builder.build_float_to_signed_int(n, ret_int_ty_llvm, "").unwrap();
AnyObject { ty: ret_int_ty, value: n.as_basic_value_enum() }
} else {
unsupported_type(ctx, [self.ty])
}
}
/// Call `np_floor()` or `np_ceil()` on this object.
#[must_use]
pub fn call_np_floor_or_ceil<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
kind: FloorOrCeil,
) -> Self {
// TODO:
if self.is_float(ctx) {
let function = match kind {
FloorOrCeil::Floor => llvm_intrinsics::call_float_floor,
FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil,
};
let n = self.into_float(ctx).value;
let n = function(ctx, n, None);
AnyObject { ty: ctx.primitives.float, value: n.as_basic_value_enum() }
} else if self.is_ndarray(ctx) {
let ndarray = self.into_ndarray(generator, ctx);
ndarray
.map(
generator,
ctx,
NDArrayOut::NewNDArray { dtype: ctx.primitives.float },
|generator, ctx, scalar| Ok(scalar.call_np_floor_or_ceil(generator, ctx, kind)),
)
.unwrap()
.to_any_object(ctx)
} else {
unsupported_type(ctx, [self.ty])
}
}
}

View File

@ -1,10 +1,10 @@
use inkwell::{values::BasicValueEnum, IntPredicate};
use super::{scalar::ScalarObject, NDArrayObject};
use super::NDArrayObject;
use crate::{
codegen::{
irrt::call_nac3_ndarray_util_assert_shape_no_negative, model::*, CodeGenContext,
CodeGenerator,
irrt::call_nac3_ndarray_util_assert_shape_no_negative, model::*, object::AnyObject,
CodeGenContext, CodeGenerator,
},
typecheck::typedef::Type,
};
@ -93,10 +93,10 @@ impl<'ctx> NDArrayObject<'ctx> {
dtype: Type,
ndims: u64,
shape: Ptr<'ctx, IntModel<SizeT>>,
fill_value: ScalarObject<'ctx>,
fill_value: AnyObject<'ctx>,
) -> Self {
// Sanity check on `fill_value`'s dtype.
assert!(ctx.unifier.unioned(dtype, fill_value.dtype));
assert!(ctx.unifier.unioned(dtype, fill_value.ty));
let ndarray = NDArrayObject::from_np_empty(generator, ctx, dtype, ndims, shape);
ndarray.fill(generator, ctx, fill_value);
@ -112,7 +112,7 @@ impl<'ctx> NDArrayObject<'ctx> {
shape: Ptr<'ctx, IntModel<SizeT>>,
) -> Self {
let fill_value = ndarray_zero_value(generator, ctx, dtype);
let fill_value = ScalarObject { value: fill_value, dtype };
let fill_value = AnyObject { value: fill_value, ty: dtype };
NDArrayObject::from_np_full(generator, ctx, dtype, ndims, shape, fill_value)
}
@ -125,7 +125,7 @@ impl<'ctx> NDArrayObject<'ctx> {
shape: Ptr<'ctx, IntModel<SizeT>>,
) -> Self {
let fill_value = ndarray_one_value(generator, ctx, dtype);
let fill_value = ScalarObject { value: fill_value, dtype };
let fill_value = AnyObject { value: fill_value, ty: dtype };
NDArrayObject::from_np_full(generator, ctx, dtype, ndims, shape, fill_value)
}

View File

@ -1,443 +1,13 @@
use inkwell::{
values::{BasicValue, FloatValue, IntValue},
FloatPredicate, IntPredicate,
};
use itertools::Itertools;
use inkwell::{FloatPredicate, IntPredicate};
use crate::{
codegen::{llvm_intrinsics, model::*, stmt::gen_if_callback, CodeGenContext, CodeGenerator},
typecheck::typedef::Type,
use crate::codegen::{
model::*,
object::{AnyObject, MinOrMax},
stmt::gen_if_callback,
CodeGenContext, CodeGenerator,
};
use super::{scalar::ScalarObject, NDArrayObject};
/// Convenience function to crash the program when types of arguments are not supported.
/// Used to be debugged with a stacktrace.
fn unsupported_type<I>(ctx: &CodeGenContext<'_, '_>, tys: I) -> !
where
I: IntoIterator<Item = Type>,
{
unreachable!(
"unsupported types found '{}'",
tys.into_iter().map(|ty| format!("'{}'", ctx.unifier.stringify(ty))).join(", "),
)
}
#[derive(Debug, Clone, Copy)]
pub enum FloorOrCeil {
Floor,
Ceil,
}
#[derive(Debug, Clone, Copy)]
pub enum MinOrMax {
Min,
Max,
}
fn signed_ints(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
vec![ctx.primitives.int32, ctx.primitives.int64]
}
fn unsigned_ints(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
vec![ctx.primitives.uint32, ctx.primitives.uint64]
}
fn ints(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
vec![ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.uint32, ctx.primitives.uint64]
}
fn int_like(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
vec![
ctx.primitives.bool,
ctx.primitives.int32,
ctx.primitives.int64,
ctx.primitives.uint32,
ctx.primitives.uint64,
]
}
fn cast_to_int_conversion<'ctx, 'a, G, HandleFloatFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
scalar: ScalarObject<'ctx>,
ret_int_dtype: Type,
handle_float: HandleFloatFn,
) -> ScalarObject<'ctx>
where
G: CodeGenerator + ?Sized,
HandleFloatFn:
FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, FloatValue<'ctx>) -> IntValue<'ctx>,
{
let ret_int_dtype_llvm = ctx.get_llvm_type(generator, ret_int_dtype).into_int_type();
let result = if ctx.unifier.unioned(scalar.dtype, ctx.primitives.float) {
// Special handling for floats
let n = scalar.value.into_float_value();
handle_float(generator, ctx, n)
} else if ctx.unifier.unioned_any(scalar.dtype, int_like(ctx)) {
let n = scalar.value.into_int_value();
if n.get_type().get_bit_width() <= ret_int_dtype_llvm.get_bit_width() {
ctx.builder.build_int_z_extend(n, ret_int_dtype_llvm, "zext").unwrap()
} else {
ctx.builder.build_int_truncate(n, ret_int_dtype_llvm, "trunc").unwrap()
}
} else {
unsupported_type(ctx, [scalar.dtype]);
};
assert_eq!(ret_int_dtype_llvm.get_bit_width(), result.get_type().get_bit_width()); // Sanity check
ScalarObject { value: result.into(), dtype: ret_int_dtype }
}
impl<'ctx> ScalarObject<'ctx> {
/// Convenience function. Assume this scalar has typechecker type float64, get its underlying LLVM value.
///
/// Panic if the type is wrong.
pub fn into_float64(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> FloatValue<'ctx> {
if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
self.value.into_float_value() // self.value must be a FloatValue
} else {
panic!("not a float type")
}
}
/// Convenience function. Assume this scalar has typechecker type int32, get its underlying LLVM value.
///
/// Panic if the type is wrong.
pub fn into_int32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
if ctx.unifier.unioned(self.dtype, ctx.primitives.int32) {
let value = self.value.into_int_value();
debug_assert_eq!(value.get_type().get_bit_width(), 32); // Sanity check
value
} else {
panic!("not a float type")
}
}
/// Compare two scalars. Only int-to-int and float-to-float comparisons are allowed.
/// Panic otherwise.
pub fn compare<G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
lhs: ScalarObject<'ctx>,
rhs: ScalarObject<'ctx>,
int_predicate: IntPredicate,
float_predicate: FloatPredicate,
name: &str,
) -> Int<'ctx, Bool> {
if !ctx.unifier.unioned(lhs.dtype, rhs.dtype) {
unsupported_type(ctx, [lhs.dtype, rhs.dtype]);
}
let bool_model = IntModel(Bool);
let common_ty = lhs.dtype;
let result = if ctx.unifier.unioned(common_ty, ctx.primitives.float) {
let lhs = lhs.value.into_float_value();
let rhs = rhs.value.into_float_value();
ctx.builder.build_float_compare(float_predicate, lhs, rhs, name).unwrap()
} else if ctx.unifier.unioned_any(common_ty, int_like(ctx)) {
let lhs = lhs.value.into_int_value();
let rhs = rhs.value.into_int_value();
ctx.builder.build_int_compare(int_predicate, lhs, rhs, name).unwrap()
} else {
unsupported_type(ctx, [lhs.dtype, rhs.dtype]);
};
bool_model.check_value(generator, ctx.ctx, result).unwrap()
}
/// Invoke NAC3's builtin `int32()`.
#[must_use]
pub fn cast_to_int32<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Self {
cast_to_int_conversion(
generator,
ctx,
*self,
ctx.primitives.int32,
|_generator, ctx, input| {
let n =
ctx.builder.build_float_to_signed_int(input, ctx.ctx.i64_type(), "").unwrap();
ctx.builder.build_int_truncate(n, ctx.ctx.i32_type(), "conv").unwrap()
},
)
}
/// Invoke NAC3's builtin `int64()`.
#[must_use]
pub fn cast_to_int64<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Self {
cast_to_int_conversion(
generator,
ctx,
*self,
ctx.primitives.int64,
|_generator, ctx, input| {
ctx.builder.build_float_to_signed_int(input, ctx.ctx.i64_type(), "").unwrap()
},
)
}
/// Invoke NAC3's builtin `uint32()`.
#[must_use]
pub fn cast_to_uint32<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Self {
cast_to_int_conversion(
generator,
ctx,
*self,
ctx.primitives.uint32,
|_generator, ctx, n| {
let n_gez = ctx
.builder
.build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "")
.unwrap();
let to_int32 =
ctx.builder.build_float_to_signed_int(n, ctx.ctx.i32_type(), "").unwrap();
let to_uint64 =
ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap();
ctx.builder
.build_select(
n_gez,
ctx.builder.build_int_truncate(to_uint64, ctx.ctx.i32_type(), "").unwrap(),
to_int32,
"conv",
)
.unwrap()
.into_int_value()
},
)
}
/// Invoke NAC3's builtin `uint64()`.
#[must_use]
pub fn cast_to_uint64<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Self {
cast_to_int_conversion(
generator,
ctx,
*self,
ctx.primitives.uint64,
|_generator, ctx, n| {
let val_gez = ctx
.builder
.build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "")
.unwrap();
let to_int64 =
ctx.builder.build_float_to_signed_int(n, ctx.ctx.i64_type(), "").unwrap();
let to_uint64 =
ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap();
ctx.builder
.build_select(val_gez, to_uint64, to_int64, "conv")
.unwrap()
.into_int_value()
},
)
}
/// Invoke NAC3's builtin `bool()`.
#[must_use]
pub fn cast_to_bool(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
// TODO: Why is the original code being so lax about i1 and i8 for the returned int type?
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.bool) {
self.value.into_int_value()
} else if ctx.unifier.unioned_any(self.dtype, ints(ctx)) {
let n = self.value.into_int_value();
ctx.builder
.build_int_compare(inkwell::IntPredicate::NE, n, n.get_type().const_zero(), "bool")
.unwrap()
} else if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
let n = self.value.into_float_value();
ctx.builder
.build_float_compare(FloatPredicate::UNE, n, n.get_type().const_zero(), "bool")
.unwrap()
} else {
unsupported_type(ctx, [self.dtype])
};
ScalarObject { dtype: ctx.primitives.bool, value: result.as_basic_value_enum() }
}
/// Invoke NAC3's builtin `float()`.
#[must_use]
pub fn cast_to_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
let llvm_f64 = ctx.ctx.f64_type();
let result: FloatValue<'_> = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
self.value.into_float_value()
} else if ctx
.unifier
.unioned_any(self.dtype, [signed_ints(ctx).as_slice(), &[ctx.primitives.bool]].concat())
{
let n = self.value.into_int_value();
ctx.builder.build_signed_int_to_float(n, llvm_f64, "sitofp").unwrap()
} else if ctx.unifier.unioned_any(self.dtype, unsigned_ints(ctx)) {
let n = self.value.into_int_value();
ctx.builder.build_unsigned_int_to_float(n, llvm_f64, "uitofp").unwrap()
} else {
unsupported_type(ctx, [self.dtype]);
};
ScalarObject { value: result.as_basic_value_enum(), dtype: ctx.primitives.float }
}
/// Invoke NAC3's builtin `round()`.
#[must_use]
pub fn round<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
ret_int_dtype: Type,
) -> Self {
let ret_int_dtype_llvm = ctx.get_llvm_type(generator, ret_int_dtype).into_int_type();
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
let n = self.value.into_float_value();
let n = llvm_intrinsics::call_float_round(ctx, n, None);
ctx.builder.build_float_to_signed_int(n, ret_int_dtype_llvm, "round").unwrap()
} else {
unsupported_type(ctx, [self.dtype, ret_int_dtype])
};
ScalarObject { dtype: ret_int_dtype, value: result.as_basic_value_enum() }
}
/// Invoke NAC3's builtin `np_round()`.
///
/// NOTE: `np.round()` has different behaviors than `round()` when in comes to "tie" cases and return type.
#[must_use]
pub fn np_round(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
let n = self.value.into_float_value();
llvm_intrinsics::call_float_rint(ctx, n, None)
} else {
unsupported_type(ctx, [self.dtype])
};
ScalarObject { dtype: ctx.primitives.float, value: result.as_basic_value_enum() }
}
/// Invoke NAC3's builtin `min()` or `max()`.
pub fn min_or_max(
ctx: &mut CodeGenContext<'ctx, '_>,
kind: MinOrMax,
a: Self,
b: Self,
) -> Self {
if !ctx.unifier.unioned(a.dtype, b.dtype) {
unsupported_type(ctx, [a.dtype, b.dtype])
}
let common_dtype = a.dtype;
if ctx.unifier.unioned(common_dtype, ctx.primitives.float) {
let function = match kind {
MinOrMax::Min => llvm_intrinsics::call_float_minnum,
MinOrMax::Max => llvm_intrinsics::call_float_maxnum,
};
let result =
function(ctx, a.value.into_float_value(), b.value.into_float_value(), None);
ScalarObject { value: result.as_basic_value_enum(), dtype: ctx.primitives.float }
} else if ctx.unifier.unioned_any(
common_dtype,
[unsigned_ints(ctx).as_slice(), &[ctx.primitives.bool]].concat(),
) {
// Treating bool has an unsigned int since that is convenient
let function = match kind {
MinOrMax::Min => llvm_intrinsics::call_int_umin,
MinOrMax::Max => llvm_intrinsics::call_int_umax,
};
let result = function(ctx, a.value.into_int_value(), b.value.into_int_value(), None);
ScalarObject { value: result.as_basic_value_enum(), dtype: common_dtype }
} else {
unsupported_type(ctx, [common_dtype])
}
}
/// Invoke NAC3's builtin `floor()` or `ceil()`.
///
/// * `ret_int_dtype` - The type of int to return.
///
/// Takes in a float/int and returns an int of type `ret_int_dtype`
#[must_use]
pub fn floor_or_ceil<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
kind: FloorOrCeil,
ret_int_dtype: Type,
) -> Self {
let ret_int_dtype_llvm = ctx.get_llvm_type(generator, ret_int_dtype).into_int_type();
if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
let function = match kind {
FloorOrCeil::Floor => llvm_intrinsics::call_float_floor,
FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil,
};
let n = self.value.into_float_value();
let n = function(ctx, n, None);
let n = ctx.builder.build_float_to_signed_int(n, ret_int_dtype_llvm, "").unwrap();
ScalarObject { dtype: ret_int_dtype, value: n.as_basic_value_enum() }
} else {
unsupported_type(ctx, [self.dtype])
}
}
/// Invoke NAC3's builtin `np_floor()`/ `np_ceil()`.
///
/// Takes in a float/int and returns a float64 result.
#[must_use]
pub fn np_floor_or_ceil(&self, ctx: &mut CodeGenContext<'ctx, '_>, kind: FloorOrCeil) -> Self {
if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
let function = match kind {
FloorOrCeil::Floor => llvm_intrinsics::call_float_floor,
FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil,
};
let n = self.value.into_float_value();
let n = function(ctx, n, None);
ScalarObject { dtype: ctx.primitives.float, value: n.as_basic_value_enum() }
} else {
unsupported_type(ctx, [self.dtype])
}
}
/// Invoke NAC3's builtin `abs()`.
#[must_use]
pub fn abs(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
let n = self.value.into_float_value();
let n = llvm_intrinsics::call_float_fabs(ctx, n, Some("abs"));
ScalarObject { value: n.into(), dtype: ctx.primitives.float }
} else if ctx.unifier.unioned_any(self.dtype, ints(ctx)) {
let n = self.value.into_int_value();
let is_poisoned = ctx.ctx.bool_type().const_zero(); // is_poisoned = false
let n = llvm_intrinsics::call_int_abs(ctx, n, is_poisoned, Some("abs"));
ScalarObject { value: n.into(), dtype: self.dtype }
} else {
unsupported_type(ctx, [self.dtype])
}
}
}
use super::NDArrayObject;
impl<'ctx> NDArrayObject<'ctx> {
/// Helper function to implement NAC3's builtin `np_min()`, `np_max()`, `np_argmin()`, and `np_argmax()`.
@ -451,7 +21,7 @@ impl<'ctx> NDArrayObject<'ctx> {
ctx: &mut CodeGenContext<'ctx, '_>,
kind: MinOrMax,
on_empty_err_msg: &str,
) -> (ScalarObject<'ctx>, Int<'ctx, SizeT>) {
) -> (AnyObject<'ctx>, Int<'ctx, SizeT>) {
let sizet_model = IntModel(SizeT);
let dtype_llvm = ctx.get_llvm_type(generator, self.dtype);
@ -478,17 +48,17 @@ impl<'ctx> NDArrayObject<'ctx> {
self.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
let old_extremum = ctx.builder.build_load(pextremum, "current_extremum").unwrap();
let old_extremum = ScalarObject { dtype: self.dtype, value: old_extremum };
let old_extremum = AnyObject { ty: self.dtype, value: old_extremum };
let scalar = nditer.get_scalar(generator, ctx);
let new_extremum = ScalarObject::min_or_max(ctx, kind, old_extremum, scalar);
let new_extremum = AnyObject::call_min_or_max(ctx, kind, old_extremum, scalar);
gen_if_callback(
generator,
ctx,
|generator, ctx| {
// Is new_extremum is more extreme than old_extremum?
let cmp = ScalarObject::compare(
let cmp = AnyObject::compare_int_or_float_by_predicate(
generator,
ctx,
new_extremum,
@ -517,7 +87,7 @@ impl<'ctx> NDArrayObject<'ctx> {
let extremum_index = pextremum_index.load(generator, ctx, "extremum_index");
let extremum = ctx.builder.build_load(pextremum, "extremum_value").unwrap();
let extremum = ScalarObject { dtype: self.dtype, value: extremum };
let extremum = AnyObject { ty: self.dtype, value: extremum };
(extremum, extremum_index)
}
@ -528,7 +98,7 @@ impl<'ctx> NDArrayObject<'ctx> {
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
kind: MinOrMax,
) -> ScalarObject<'ctx> {
) -> AnyObject<'ctx> {
let on_empty_err_msg = format!(
"zero-size array to reduction operation {} which has no identity",
match kind {

View File

@ -1,9 +1,8 @@
use inkwell::values::BasicValueEnum;
use itertools::Itertools;
use crate::{
codegen::{
object::ndarray::{NDArrayObject, ScalarObject},
object::ndarray::{AnyObject, NDArrayObject},
stmt::gen_for_callback,
CodeGenContext, CodeGenerator,
},
@ -27,8 +26,8 @@ impl<'ctx> NDArrayObject<'ctx> {
MappingFn: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
&[ScalarObject<'ctx>],
) -> Result<BasicValueEnum<'ctx>, String>,
&[AnyObject<'ctx>],
) -> Result<AnyObject<'ctx>, String>,
{
// Broadcast inputs
let broadcast_result = NDArrayObject::broadcast(generator, ctx, ndarrays);
@ -89,9 +88,11 @@ impl<'ctx> NDArrayObject<'ctx> {
in_nditers.iter().map(|nditer| nditer.get_scalar(generator, ctx)).collect_vec();
let result = mapping(generator, ctx, &in_scalars)?;
// Sanity check on result's ty
assert!(ctx.unifier.unioned(result.ty, out_ndarray.dtype));
let p = out_nditer.get_pointer(generator, ctx);
ctx.builder.build_store(p, result).unwrap();
ctx.builder.build_store(p, result.value).unwrap();
Ok(())
},
@ -103,20 +104,6 @@ impl<'ctx> NDArrayObject<'ctx> {
},
)?;
// let start = sizet_model.const_0(generator, ctx.ctx);
// let stop = broadcast_result.ndarrays[0].size(generator, ctx); // They all should have the same `np.size`.
// let step = sizet_model.const_1(generator, ctx.ctx);
// gen_for_model_auto(generator, ctx, start, stop, step, move |generator, ctx, _hooks, i| {
// let elements =
// ndarrays.iter().map(|ndarray| ndarray.get_nth_scalar(generator, ctx, i)).collect_vec();
// let ret = mapping(generator, ctx, i, &elements)?;
// let pret = out_ndarray.get_nth_pointer(generator, ctx, i, "pret");
// ctx.builder.build_store(pret, ret).unwrap();
// Ok(())
// })?;
Ok(out_ndarray)
}
@ -132,8 +119,8 @@ impl<'ctx> NDArrayObject<'ctx> {
Mapping: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
ScalarObject<'ctx>,
) -> Result<BasicValueEnum<'ctx>, String>,
AnyObject<'ctx>,
) -> Result<AnyObject<'ctx>, String>,
{
NDArrayObject::broadcasting_starmap(
generator,
@ -160,16 +147,18 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
MappingFn: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
&[ScalarObject<'ctx>],
) -> Result<BasicValueEnum<'ctx>, String>,
&[AnyObject<'ctx>],
) -> Result<AnyObject<'ctx>, String>,
{
// Check if all inputs are ScalarObjects
let all_scalars: Option<Vec<_>> =
inputs.iter().map(ScalarObject::try_from).try_collect().ok();
// Check if all inputs are AnyObjects
let all_scalars: Option<Vec<_>> = inputs.iter().map(AnyObject::try_from).try_collect().ok();
if let Some(scalars) = all_scalars {
let scalar =
ScalarObject { value: mapping(generator, ctx, &scalars)?, dtype: ret_dtype };
let scalar = mapping(generator, ctx, &scalars)?;
// Sanity check on scalar's type
assert!(ctx.unifier.unioned(scalar.ty, ret_dtype));
Ok(ScalarOrNDArray::Scalar(scalar))
} else {
// Promote all input to ndarrays and map through them.
@ -197,8 +186,8 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
Mapping: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
ScalarObject<'ctx>,
) -> Result<BasicValueEnum<'ctx>, String>,
AnyObject<'ctx>,
) -> Result<AnyObject<'ctx>, String>,
{
ScalarOrNDArray::broadcasting_starmap(
generator,

View File

@ -38,7 +38,7 @@ use inkwell::{
AddressSpace, IntPredicate,
};
use nditer::NDIterHandle;
use scalar::{ScalarObject, ScalarOrNDArray};
use scalar::ScalarOrNDArray;
use util::call_memcpy_model;
use super::{tuple::TupleObject, AnyObject};
@ -257,10 +257,10 @@ impl<'ctx> NDArrayObject<'ctx> {
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
nth: Int<'ctx, SizeT>,
) -> ScalarObject<'ctx> {
) -> AnyObject<'ctx> {
let p = self.get_nth_pointer(generator, ctx, nth, "value");
let value = ctx.builder.build_load(p, "value").unwrap();
ScalarObject { dtype: self.dtype, value }
AnyObject { ty: self.dtype, value }
}
/// Set the n-th (0-based) scalar.
@ -271,10 +271,10 @@ impl<'ctx> NDArrayObject<'ctx> {
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
nth: Int<'ctx, SizeT>,
scalar: ScalarObject<'ctx>,
scalar: AnyObject<'ctx>,
) {
// Sanity check on scalar's `dtype`
assert!(ctx.unifier.unioned(scalar.dtype, self.dtype));
assert!(ctx.unifier.unioned(scalar.ty, self.dtype));
let pscalar = self.get_nth_pointer(generator, ctx, nth, "pscalar");
ctx.builder.build_store(pscalar, scalar.value).unwrap();
@ -284,7 +284,7 @@ impl<'ctx> NDArrayObject<'ctx> {
///
/// Please refer to the IRRT implementation to see its purpose.
pub fn update_strides_by_shape<G: CodeGenerator + ?Sized>(
&self,
self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) {
@ -458,7 +458,7 @@ impl<'ctx> NDArrayObject<'ctx> {
self.ndims == 0
}
/// If this ndarray is unsized, return its sole value as a [`ScalarObject`]. Otherwise, do nothing.
/// If this ndarray is unsized, return its sole value as a [`AnyObject`]. Otherwise, do nothing.
pub fn split_unsized<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
@ -604,10 +604,10 @@ impl<'ctx> NDArrayObject<'ctx> {
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
scalar: ScalarObject<'ctx>,
scalar: AnyObject<'ctx>,
) {
// Sanity check on scalar's type.
assert!(ctx.unifier.unioned(self.dtype, scalar.dtype));
assert!(ctx.unifier.unioned(self.dtype, scalar.ty));
self.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
let p = nditer.get_pointer(generator, ctx);

View File

@ -3,7 +3,7 @@ use inkwell::{types::BasicType, values::PointerValue, AddressSpace};
use crate::codegen::{
irrt::{call_nac3_nditer_has_next, call_nac3_nditer_initialize, call_nac3_nditer_next},
model::*,
object::ndarray::scalar::ScalarObject,
object::AnyObject,
structure::NDIter,
CodeGenContext, CodeGenerator,
};
@ -63,10 +63,10 @@ impl<'ctx> NDIterHandle<'ctx> {
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> ScalarObject<'ctx> {
) -> AnyObject<'ctx> {
let p = self.get_pointer(generator, ctx);
let value = ctx.builder.build_load(p, "value").unwrap();
ScalarObject { dtype: self.ndarray.dtype, value }
AnyObject { ty: self.ndarray.dtype, value }
}
pub fn get_index<G: CodeGenerator + ?Sized>(

View File

@ -7,18 +7,7 @@ use crate::{
use super::NDArrayObject;
/// An LLVM numpy scalar with its [`Type`].
///
/// Intended to be used with [`ScalarOrNDArray`].
///
/// A scalar does not have to be an actual number. It could be arbitrary objects.
#[derive(Debug, Clone, Copy)]
pub struct ScalarObject<'ctx> {
pub dtype: Type,
pub value: BasicValueEnum<'ctx>,
}
impl<'ctx> ScalarObject<'ctx> {
impl<'ctx> AnyObject<'ctx> {
/// Promote this scalar to an unsized ndarray (like doing `np.asarray`).
///
/// The scalar value is allocated onto the stack, and the ndarray's `data` will point to that
@ -35,7 +24,7 @@ impl<'ctx> ScalarObject<'ctx> {
ctx.builder.build_store(data, self.value).unwrap();
let data = pbyte_model.pointer_cast(generator, ctx, data, "data");
let ndarray = NDArrayObject::alloca(generator, ctx, self.dtype, 0, "scalar_ndarray");
let ndarray = NDArrayObject::alloca(generator, ctx, self.ty, 0, "scalar_ndarray");
ndarray.instance.set(ctx, |f| f.data, data);
ndarray
}
@ -44,7 +33,7 @@ impl<'ctx> ScalarObject<'ctx> {
/// A convenience enum for implementing scalar/ndarray agnostic utilities.
#[derive(Debug, Clone, Copy)]
pub enum ScalarOrNDArray<'ctx> {
Scalar(ScalarObject<'ctx>),
Scalar(AnyObject<'ctx>),
NDArray(NDArrayObject<'ctx>),
}
@ -59,7 +48,7 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
}
#[must_use]
pub fn into_scalar(&self) -> ScalarObject<'ctx> {
pub fn into_scalar(&self) -> AnyObject<'ctx> {
match self {
ScalarOrNDArray::NDArray(_ndarray) => panic!("Got NDArray"),
ScalarOrNDArray::Scalar(scalar) => *scalar,
@ -91,13 +80,13 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
#[must_use]
pub fn dtype(&self) -> Type {
match self {
ScalarOrNDArray::Scalar(scalar) => scalar.dtype,
ScalarOrNDArray::Scalar(scalar) => scalar.ty,
ScalarOrNDArray::NDArray(ndarray) => ndarray.dtype,
}
}
}
impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for ScalarObject<'ctx> {
impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for AnyObject<'ctx> {
type Error = ();
fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result<Self, Self::Error> {
@ -135,7 +124,7 @@ pub fn split_scalar_or_ndarray<'ctx, G: CodeGenerator + ?Sized>(
ScalarOrNDArray::NDArray(ndarray)
}
_ => {
let scalar = ScalarObject { dtype: object.ty, value: object.value };
let scalar = AnyObject { ty: object.ty, value: object.value };
ScalarOrNDArray::Scalar(scalar)
}
}

View File

@ -1143,12 +1143,12 @@ impl<'a> BuiltinBuilder<'a> {
ret_dtype,
|generator, ctx, scalar| {
let result = match prim {
PrimDef::FunInt32 => scalar.cast_to_int32(generator, ctx),
PrimDef::FunInt64 => scalar.cast_to_int64(generator, ctx),
PrimDef::FunUInt32 => scalar.cast_to_uint32(generator, ctx),
PrimDef::FunUInt64 => scalar.cast_to_uint64(generator, ctx),
PrimDef::FunFloat => scalar.cast_to_float(ctx),
PrimDef::FunBool => scalar.cast_to_bool(ctx),
PrimDef::FunInt32 => scalar.call_int32(generator, ctx),
PrimDef::FunInt64 => scalar.call_int64(generator, ctx),
PrimDef::FunUInt32 => scalar.call_uint32(generator, ctx),
PrimDef::FunUInt64 => scalar.call_uint64(generator, ctx),
PrimDef::FunFloat => scalar.call_float(ctx),
PrimDef::FunBool => scalar.call_bool(ctx),
_ => unreachable!(),
};
Ok(result.value)
@ -1277,7 +1277,7 @@ impl<'a> BuiltinBuilder<'a> {
ctx,
int_sized,
|generator, ctx, scalar| {
Ok(scalar.floor_or_ceil(generator, ctx, kind, int_sized).value)
Ok(scalar.call_floor_or_ceil(generator, ctx, kind, int_sized).value)
},
)?;
Ok(Some(result.to_basic_value_enum()))
@ -1927,7 +1927,7 @@ impl<'a> BuiltinBuilder<'a> {
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
let arg = AnyObject { value: arg, ty: arg_ty };
Ok(Some(arg.len(generator, ctx).value.as_basic_value_enum()))
Ok(Some(arg.call_len(generator, ctx).value.as_basic_value_enum()))
},
)))),
loc: None,