forked from M-Labs/nac3
WIP: core/ndstrides: remove ScalarObject
This commit is contained in:
parent
f8b934096d
commit
4b765cfb27
@ -1,18 +1,71 @@
|
||||
use inkwell::values::BasicValueEnum;
|
||||
use inkwell::{
|
||||
values::{BasicValue, BasicValueEnum, FloatValue, IntValue},
|
||||
FloatPredicate, IntPredicate,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use list::ListObject;
|
||||
use ndarray::NDArrayObject;
|
||||
use ndarray::{NDArrayObject, NDArrayOut};
|
||||
use range::RangeObject;
|
||||
use tuple::TupleObject;
|
||||
|
||||
use crate::typecheck::typedef::{Type, TypeEnum};
|
||||
use crate::{
|
||||
toplevel::helper::PrimDef,
|
||||
typecheck::typedef::{Type, TypeEnum},
|
||||
};
|
||||
|
||||
use super::{model::*, CodeGenContext, CodeGenerator};
|
||||
use super::{llvm_intrinsics, model::*, CodeGenContext, CodeGenerator};
|
||||
|
||||
pub mod list;
|
||||
pub mod ndarray;
|
||||
pub mod range;
|
||||
pub mod tuple;
|
||||
|
||||
/// Convenience function to crash the program when types of arguments are not supported.
|
||||
/// Used to be debugged with a stacktrace.
|
||||
fn unsupported_type<I>(ctx: &CodeGenContext<'_, '_>, tys: I) -> !
|
||||
where
|
||||
I: IntoIterator<Item = Type>,
|
||||
{
|
||||
unreachable!(
|
||||
"unsupported types found '{}'",
|
||||
tys.into_iter().map(|ty| format!("'{}'", ctx.unifier.stringify(ty))).join(", "),
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum FloorOrCeil {
|
||||
Floor,
|
||||
Ceil,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum MinOrMax {
|
||||
Min,
|
||||
Max,
|
||||
}
|
||||
|
||||
fn signed_ints(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
|
||||
vec![ctx.primitives.int32, ctx.primitives.int64]
|
||||
}
|
||||
|
||||
fn unsigned_ints(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
|
||||
vec![ctx.primitives.uint32, ctx.primitives.uint64]
|
||||
}
|
||||
|
||||
fn ints(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
|
||||
vec![ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.uint32, ctx.primitives.uint64]
|
||||
}
|
||||
|
||||
fn int_like(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
|
||||
vec![
|
||||
ctx.primitives.bool,
|
||||
ctx.primitives.int32,
|
||||
ctx.primitives.int64,
|
||||
ctx.primitives.uint32,
|
||||
ctx.primitives.uint64,
|
||||
]
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct AnyObject<'ctx> {
|
||||
pub ty: Type,
|
||||
@ -20,13 +73,260 @@ pub struct AnyObject<'ctx> {
|
||||
}
|
||||
|
||||
impl<'ctx> AnyObject<'ctx> {
|
||||
// Get the `len()` of this object.
|
||||
pub fn len<G: CodeGenerator + ?Sized>(
|
||||
/// Returns true if this object's type is a [`TypeEnum::TObj`] and has the object ID as `prim`.
|
||||
pub fn is_obj(&self, ctx: &mut CodeGenContext<'ctx, '_>, prim: PrimDef) -> bool {
|
||||
match &*ctx.unifier.get_ty(self.ty) {
|
||||
TypeEnum::TObj { obj_id, .. } => *obj_id == prim.id(),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if this object's type is a [`TypeEnum::TTuple`]
|
||||
pub fn is_tuple(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
|
||||
matches!(&*ctx.unifier.get_ty(self.ty), TypeEnum::TTuple { .. })
|
||||
}
|
||||
|
||||
pub fn is_int32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
|
||||
ctx.unifier.unioned(self.ty, ctx.primitives.int32)
|
||||
}
|
||||
|
||||
pub fn is_uint32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
|
||||
ctx.unifier.unioned(self.ty, ctx.primitives.uint32)
|
||||
}
|
||||
|
||||
pub fn is_int64(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
|
||||
ctx.unifier.unioned(self.ty, ctx.primitives.int64)
|
||||
}
|
||||
|
||||
pub fn is_uint64(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
|
||||
ctx.unifier.unioned(self.ty, ctx.primitives.uint64)
|
||||
}
|
||||
|
||||
pub fn is_bool(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
|
||||
ctx.unifier.unioned(self.ty, ctx.primitives.bool)
|
||||
}
|
||||
|
||||
pub fn is_int_like(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
|
||||
ctx.unifier.unioned_any(self.ty, int_like(ctx))
|
||||
}
|
||||
|
||||
pub fn is_signed_int(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
|
||||
ctx.unifier.unioned_any(self.ty, signed_ints(ctx))
|
||||
}
|
||||
|
||||
pub fn is_unsigned_int(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
|
||||
ctx.unifier.unioned_any(self.ty, unsigned_ints(ctx))
|
||||
}
|
||||
|
||||
/// Convenience function. If object has type `int32`, `int64`, `uint32`, `uint64`, or `bool`,
|
||||
/// get its underlying LLVM value.
|
||||
///
|
||||
/// Panic if the type is wrong.
|
||||
pub fn into_int(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
|
||||
if self.is_int_like(ctx) {
|
||||
self.value.into_int_value()
|
||||
} else {
|
||||
panic!("not an int32 type")
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
|
||||
self.is_obj(ctx, PrimDef::Float)
|
||||
}
|
||||
|
||||
/// Convenience function. If object has type `float`, get its underlying LLVM value.
|
||||
///
|
||||
/// Panic if the type is wrong.
|
||||
pub fn into_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Float<'ctx, Float64> {
|
||||
if self.is_float(ctx) {
|
||||
// self.value must be a FloatValue
|
||||
FloatModel(Float64).believe_value(self.value.into_float_value())
|
||||
} else {
|
||||
panic!("not a float type")
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_ndarray(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
|
||||
self.is_obj(ctx, PrimDef::NDArray)
|
||||
}
|
||||
|
||||
pub fn into_ndarray<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> Int<'ctx, Int32> {
|
||||
match &*ctx.unifier.get_ty_immutable(self.ty) {
|
||||
) -> NDArrayObject<'ctx> {
|
||||
NDArrayObject::from_object(generator, ctx, *self)
|
||||
}
|
||||
|
||||
/// Helper function to compare two scalars.
|
||||
///
|
||||
/// Only int-to-int and float-to-float comparisons are allowed.
|
||||
///
|
||||
/// Panic otherwise.
|
||||
pub fn compare_int_or_float_by_predicate<G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
lhs: AnyObject<'ctx>,
|
||||
rhs: AnyObject<'ctx>,
|
||||
int_predicate: IntPredicate,
|
||||
float_predicate: FloatPredicate,
|
||||
name: &str,
|
||||
) -> Int<'ctx, Bool> {
|
||||
if !ctx.unifier.unioned(lhs.ty, rhs.ty) {
|
||||
panic!("lhs and rhs type are not the same.")
|
||||
}
|
||||
|
||||
let bool_model = IntModel(Bool);
|
||||
|
||||
let common_ty = lhs.ty;
|
||||
let result = if lhs.is_float(ctx) {
|
||||
let lhs = lhs.into_float(ctx);
|
||||
let rhs = rhs.into_float(ctx);
|
||||
ctx.builder.build_float_compare(float_predicate, lhs.value, rhs.value, name).unwrap()
|
||||
} else if ctx.unifier.unioned_any(common_ty, int_like(ctx)) {
|
||||
let lhs = lhs.into_int(ctx);
|
||||
let rhs = rhs.into_int(ctx);
|
||||
ctx.builder.build_int_compare(int_predicate, lhs, rhs, name).unwrap()
|
||||
} else {
|
||||
unsupported_type(ctx, [lhs.ty, rhs.ty]);
|
||||
};
|
||||
|
||||
bool_model.check_value(generator, ctx.ctx, result).unwrap()
|
||||
}
|
||||
|
||||
/// Helper function for `int32()`, `int64()`, `uint32()`, and `uint64()`.
|
||||
///
|
||||
/// TODO: Document me
|
||||
fn cast_to_int_conversion<'a, G, HandleFloatFn>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ret_int_ty: Type,
|
||||
handle_float: HandleFloatFn,
|
||||
) -> AnyObject<'ctx>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
HandleFloatFn:
|
||||
FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, FloatValue<'ctx>) -> IntValue<'ctx>,
|
||||
{
|
||||
let ret_int_ty_llvm = ctx.get_llvm_type(generator, ret_int_ty).into_int_type();
|
||||
|
||||
let result = if self.is_float(ctx) {
|
||||
// Handle float to int
|
||||
let n = self.into_float(ctx);
|
||||
handle_float(generator, ctx, n.value)
|
||||
} else if self.is_int_like(ctx) {
|
||||
// Handle int to a new int type
|
||||
let n = self.into_int(ctx);
|
||||
if n.get_type().get_bit_width() <= ret_int_ty_llvm.get_bit_width() {
|
||||
ctx.builder.build_int_z_extend(n, ret_int_ty_llvm, "zext").unwrap()
|
||||
} else {
|
||||
ctx.builder.build_int_truncate(n, ret_int_ty_llvm, "trunc").unwrap()
|
||||
}
|
||||
} else {
|
||||
unsupported_type(ctx, [self.ty]);
|
||||
};
|
||||
|
||||
assert_eq!(ret_int_ty_llvm.get_bit_width(), result.get_type().get_bit_width()); // Sanity check
|
||||
AnyObject { value: result.into(), ty: ret_int_ty }
|
||||
}
|
||||
|
||||
/// Call `int32()` on this object.
|
||||
#[must_use]
|
||||
pub fn call_int32<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> AnyObject<'ctx> {
|
||||
self.cast_to_int_conversion(
|
||||
generator,
|
||||
ctx,
|
||||
ctx.primitives.int32,
|
||||
|_generator, ctx, input| {
|
||||
let n =
|
||||
ctx.builder.build_float_to_signed_int(input, ctx.ctx.i64_type(), "").unwrap();
|
||||
ctx.builder.build_int_truncate(n, ctx.ctx.i32_type(), "conv").unwrap()
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
/// Call `int64()` on this object.
|
||||
#[must_use]
|
||||
pub fn call_int64<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> AnyObject<'ctx> {
|
||||
self.cast_to_int_conversion(
|
||||
generator,
|
||||
ctx,
|
||||
ctx.primitives.int64,
|
||||
|_generator, ctx, input| {
|
||||
ctx.builder.build_float_to_signed_int(input, ctx.ctx.i64_type(), "").unwrap()
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
/// Call `uint32()` on this object.
|
||||
#[must_use]
|
||||
pub fn call_uint32<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> AnyObject<'ctx> {
|
||||
self.cast_to_int_conversion(generator, ctx, ctx.primitives.uint32, |_generator, ctx, n| {
|
||||
let n_gez = ctx
|
||||
.builder
|
||||
.build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "")
|
||||
.unwrap();
|
||||
|
||||
let to_int32 =
|
||||
ctx.builder.build_float_to_signed_int(n, ctx.ctx.i32_type(), "").unwrap();
|
||||
let to_uint64 =
|
||||
ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_select(
|
||||
n_gez,
|
||||
ctx.builder.build_int_truncate(to_uint64, ctx.ctx.i32_type(), "").unwrap(),
|
||||
to_int32,
|
||||
"conv",
|
||||
)
|
||||
.unwrap()
|
||||
.into_int_value()
|
||||
})
|
||||
}
|
||||
|
||||
/// Call `uint64()` on this object.
|
||||
#[must_use]
|
||||
pub fn call_uint64<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> AnyObject<'ctx> {
|
||||
self.cast_to_int_conversion(generator, ctx, ctx.primitives.uint64, |_generator, ctx, n| {
|
||||
let val_gez = ctx
|
||||
.builder
|
||||
.build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "")
|
||||
.unwrap();
|
||||
|
||||
let to_int64 =
|
||||
ctx.builder.build_float_to_signed_int(n, ctx.ctx.i64_type(), "").unwrap();
|
||||
let to_uint64 =
|
||||
ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap();
|
||||
|
||||
ctx.builder.build_select(val_gez, to_uint64, to_int64, "conv").unwrap().into_int_value()
|
||||
})
|
||||
}
|
||||
|
||||
// Get the `len()` of this object.
|
||||
pub fn call_len<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> AnyObject<'ctx> {
|
||||
// TODO: Switch to returning SizeT
|
||||
let result = match &*ctx.unifier.get_ty_immutable(self.ty) {
|
||||
TypeEnum::TTuple { .. } => {
|
||||
let tuple = TupleObject::from_object(ctx, *self);
|
||||
tuple.len(generator, ctx).truncate(generator, ctx, Int32, "tuple_len_32")
|
||||
@ -50,6 +350,259 @@ impl<'ctx> AnyObject<'ctx> {
|
||||
ndarray.len(generator, ctx).truncate(generator, ctx, Int32, "ndarray_len_i32")
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
AnyObject { ty: ctx.primitives.int32, value: result.value.as_basic_value_enum() }
|
||||
}
|
||||
|
||||
/// Like [`AnyObject::call_bool`] but this returns an `Int<'ctx, Bool>` instead of an object.
|
||||
pub fn bool(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Int<'ctx, Bool> {
|
||||
let bool_model = IntModel(Bool);
|
||||
if self.is_int_like(ctx) {
|
||||
let n = self.into_int(ctx);
|
||||
let n = ctx
|
||||
.builder
|
||||
.build_int_compare(inkwell::IntPredicate::NE, n, n.get_type().const_zero(), "bool")
|
||||
.unwrap();
|
||||
bool_model.believe_value(n)
|
||||
} else if self.is_float(ctx) {
|
||||
let n = self.value.into_float_value();
|
||||
let n = ctx
|
||||
.builder
|
||||
.build_float_compare(FloatPredicate::UNE, n, n.get_type().const_zero(), "bool")
|
||||
.unwrap();
|
||||
bool_model.believe_value(n)
|
||||
} else {
|
||||
unsupported_type(ctx, [self.ty])
|
||||
}
|
||||
}
|
||||
|
||||
/// Call `bool()` on this object.
|
||||
pub fn call_bool<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> AnyObject<'ctx> {
|
||||
let n = self.bool(ctx);
|
||||
|
||||
// NAC3 booleans are i8
|
||||
let llvm_i8 = ctx.ctx.i8_type();
|
||||
let n = ctx.builder.build_int_z_extend(n.value, llvm_i8, "bool").unwrap();
|
||||
|
||||
AnyObject { ty: ctx.primitives.bool, value: n.as_basic_value_enum() }
|
||||
}
|
||||
|
||||
/// Call `float()` on this object.
|
||||
pub fn call_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> AnyObject<'ctx> {
|
||||
let f64_model = FloatModel(Float64);
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
|
||||
let result = if self.is_float(ctx) {
|
||||
self.into_float(ctx)
|
||||
} else if self.is_signed_int(ctx) || self.is_bool(ctx) {
|
||||
let n = self.into_int(ctx);
|
||||
let n = ctx.builder.build_signed_int_to_float(n, llvm_f64, "sitofp").unwrap();
|
||||
f64_model.believe_value(n)
|
||||
} else if self.is_unsigned_int(ctx) {
|
||||
let n = self.into_int(ctx);
|
||||
let n = ctx.builder.build_unsigned_int_to_float(n, llvm_f64, "uitofp").unwrap();
|
||||
f64_model.believe_value(n)
|
||||
} else {
|
||||
unsupported_type(ctx, [self.ty]);
|
||||
};
|
||||
|
||||
AnyObject { ty: ctx.primitives.float, value: result.value.as_basic_value_enum() }
|
||||
}
|
||||
|
||||
// Call `abs()` on this object.
|
||||
#[must_use]
|
||||
pub fn call_abs<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> Self {
|
||||
if self.is_float(ctx) {
|
||||
let n = self.value.into_float_value();
|
||||
let n = llvm_intrinsics::call_float_fabs(ctx, n, Some("abs"));
|
||||
AnyObject { value: n.into(), ty: ctx.primitives.float }
|
||||
} else if self.is_unsigned_int(ctx) || self.is_signed_int(ctx) {
|
||||
let is_poisoned = ctx.ctx.bool_type().const_zero(); // is_poisoned = false
|
||||
|
||||
let n = self.value.into_int_value();
|
||||
let n = llvm_intrinsics::call_int_abs(ctx, n, is_poisoned, Some("abs"));
|
||||
AnyObject { value: n.into(), ty: self.ty }
|
||||
} else if self.is_ndarray(ctx) {
|
||||
let ndarray = self.into_ndarray(generator, ctx);
|
||||
ndarray
|
||||
.map(
|
||||
generator,
|
||||
ctx,
|
||||
NDArrayOut::NewNDArray { dtype: ndarray.dtype },
|
||||
|generator, ctx, scalar| Ok(scalar.call_abs(generator, ctx)),
|
||||
)
|
||||
.unwrap()
|
||||
.to_any_object(ctx)
|
||||
} else {
|
||||
unsupported_type(ctx, [self.ty])
|
||||
}
|
||||
}
|
||||
|
||||
// Call `round()` on this object.
|
||||
//
|
||||
// It is possible to specify which kind of int type to return.
|
||||
pub fn call_round<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ret_int_ty: Type,
|
||||
) -> AnyObject<'ctx> {
|
||||
let ret_int_ty_llvm = ctx.get_llvm_type(generator, ret_int_ty).into_int_type();
|
||||
|
||||
let result = if ctx.unifier.unioned(self.ty, ctx.primitives.float) {
|
||||
let n = self.value.into_float_value();
|
||||
let n = llvm_intrinsics::call_float_round(ctx, n, None);
|
||||
ctx.builder.build_float_to_signed_int(n, ret_int_ty_llvm, "round").unwrap()
|
||||
} else {
|
||||
unsupported_type(ctx, [self.ty])
|
||||
};
|
||||
AnyObject { ty: ret_int_ty, value: result.as_basic_value_enum() }
|
||||
}
|
||||
|
||||
/// Call `np_round()` on this object.
|
||||
///
|
||||
/// NOTE: `np.round()` has different behaviors than `round()` when in comes to "tie" cases and return type.
|
||||
#[must_use]
|
||||
pub fn call_np_round<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> AnyObject<'ctx> {
|
||||
if self.is_float(ctx) {
|
||||
let n = self.into_float(ctx);
|
||||
let n = llvm_intrinsics::call_float_rint(ctx, n.value, None);
|
||||
AnyObject { ty: ctx.primitives.float, value: n.as_basic_value_enum() }
|
||||
} else if self.is_ndarray(ctx) {
|
||||
let ndarray = self.into_ndarray(generator, ctx);
|
||||
ndarray
|
||||
.map(
|
||||
generator,
|
||||
ctx,
|
||||
NDArrayOut::NewNDArray { dtype: ndarray.dtype },
|
||||
|generator, ctx, scalar| Ok(scalar.call_np_round(generator, ctx)),
|
||||
)
|
||||
.unwrap()
|
||||
.to_any_object(ctx)
|
||||
} else {
|
||||
unsupported_type(ctx, [self.ty])
|
||||
}
|
||||
}
|
||||
|
||||
/// Call `min()` or `max()` on two objects.
|
||||
pub fn call_min_or_max(
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
kind: MinOrMax,
|
||||
a: AnyObject<'ctx>,
|
||||
b: AnyObject<'ctx>,
|
||||
) -> AnyObject<'ctx> {
|
||||
if !ctx.unifier.unioned(a.ty, b.ty) {
|
||||
unsupported_type(ctx, [a.ty, b.ty])
|
||||
}
|
||||
|
||||
let common_ty = a.ty;
|
||||
|
||||
if a.is_float(ctx) {
|
||||
let function = match kind {
|
||||
MinOrMax::Min => llvm_intrinsics::call_float_minnum,
|
||||
MinOrMax::Max => llvm_intrinsics::call_float_maxnum,
|
||||
};
|
||||
|
||||
let a = a.into_float(ctx).value;
|
||||
let b = b.into_float(ctx).value;
|
||||
let result = function(ctx, a, b, None);
|
||||
AnyObject { value: result.as_basic_value_enum(), ty: ctx.primitives.float }
|
||||
} else if a.is_unsigned_int(ctx) || a.is_bool(ctx) {
|
||||
// Treating bool has an unsigned int since that is convenient
|
||||
let function = match kind {
|
||||
MinOrMax::Min => llvm_intrinsics::call_int_umin,
|
||||
MinOrMax::Max => llvm_intrinsics::call_int_umax,
|
||||
};
|
||||
|
||||
let a = a.into_int(ctx);
|
||||
let b = b.into_int(ctx);
|
||||
let result = function(ctx, a, b, None);
|
||||
AnyObject { value: result.as_basic_value_enum(), ty: common_ty }
|
||||
} else if a.is_signed_int(ctx) {
|
||||
let function = match kind {
|
||||
MinOrMax::Min => llvm_intrinsics::call_int_smin,
|
||||
MinOrMax::Max => llvm_intrinsics::call_int_smax,
|
||||
};
|
||||
|
||||
let a = a.into_int(ctx);
|
||||
let b = b.into_int(ctx);
|
||||
let result = function(ctx, a, b, None);
|
||||
AnyObject { value: result.as_basic_value_enum(), ty: common_ty }
|
||||
} else {
|
||||
unsupported_type(ctx, [common_ty])
|
||||
}
|
||||
}
|
||||
|
||||
/// Call `floor()` or `ceil()` on this object.
|
||||
///
|
||||
/// It is possible to specify which kind of int type to return.
|
||||
#[must_use]
|
||||
pub fn call_floor_or_ceil<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
kind: FloorOrCeil,
|
||||
ret_int_ty: Type,
|
||||
) -> Self {
|
||||
let ret_int_ty_llvm = ctx.get_llvm_type(generator, ret_int_ty).into_int_type();
|
||||
|
||||
if self.is_float(ctx) {
|
||||
let function = match kind {
|
||||
FloorOrCeil::Floor => llvm_intrinsics::call_float_floor,
|
||||
FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil,
|
||||
};
|
||||
|
||||
let n = self.into_float(ctx).value;
|
||||
let n = function(ctx, n, None);
|
||||
let n = ctx.builder.build_float_to_signed_int(n, ret_int_ty_llvm, "").unwrap();
|
||||
AnyObject { ty: ret_int_ty, value: n.as_basic_value_enum() }
|
||||
} else {
|
||||
unsupported_type(ctx, [self.ty])
|
||||
}
|
||||
}
|
||||
|
||||
/// Call `np_floor()` or `np_ceil()` on this object.
|
||||
#[must_use]
|
||||
pub fn call_np_floor_or_ceil<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
kind: FloorOrCeil,
|
||||
) -> Self {
|
||||
// TODO:
|
||||
if self.is_float(ctx) {
|
||||
let function = match kind {
|
||||
FloorOrCeil::Floor => llvm_intrinsics::call_float_floor,
|
||||
FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil,
|
||||
};
|
||||
let n = self.into_float(ctx).value;
|
||||
let n = function(ctx, n, None);
|
||||
AnyObject { ty: ctx.primitives.float, value: n.as_basic_value_enum() }
|
||||
} else if self.is_ndarray(ctx) {
|
||||
let ndarray = self.into_ndarray(generator, ctx);
|
||||
ndarray
|
||||
.map(
|
||||
generator,
|
||||
ctx,
|
||||
NDArrayOut::NewNDArray { dtype: ctx.primitives.float },
|
||||
|generator, ctx, scalar| Ok(scalar.call_np_floor_or_ceil(generator, ctx, kind)),
|
||||
)
|
||||
.unwrap()
|
||||
.to_any_object(ctx)
|
||||
} else {
|
||||
unsupported_type(ctx, [self.ty])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,10 +1,10 @@
|
||||
use inkwell::{values::BasicValueEnum, IntPredicate};
|
||||
|
||||
use super::{scalar::ScalarObject, NDArrayObject};
|
||||
use super::NDArrayObject;
|
||||
use crate::{
|
||||
codegen::{
|
||||
irrt::call_nac3_ndarray_util_assert_shape_no_negative, model::*, CodeGenContext,
|
||||
CodeGenerator,
|
||||
irrt::call_nac3_ndarray_util_assert_shape_no_negative, model::*, object::AnyObject,
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
typecheck::typedef::Type,
|
||||
};
|
||||
@ -93,10 +93,10 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
dtype: Type,
|
||||
ndims: u64,
|
||||
shape: Ptr<'ctx, IntModel<SizeT>>,
|
||||
fill_value: ScalarObject<'ctx>,
|
||||
fill_value: AnyObject<'ctx>,
|
||||
) -> Self {
|
||||
// Sanity check on `fill_value`'s dtype.
|
||||
assert!(ctx.unifier.unioned(dtype, fill_value.dtype));
|
||||
assert!(ctx.unifier.unioned(dtype, fill_value.ty));
|
||||
|
||||
let ndarray = NDArrayObject::from_np_empty(generator, ctx, dtype, ndims, shape);
|
||||
ndarray.fill(generator, ctx, fill_value);
|
||||
@ -112,7 +112,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
shape: Ptr<'ctx, IntModel<SizeT>>,
|
||||
) -> Self {
|
||||
let fill_value = ndarray_zero_value(generator, ctx, dtype);
|
||||
let fill_value = ScalarObject { value: fill_value, dtype };
|
||||
let fill_value = AnyObject { value: fill_value, ty: dtype };
|
||||
NDArrayObject::from_np_full(generator, ctx, dtype, ndims, shape, fill_value)
|
||||
}
|
||||
|
||||
@ -125,7 +125,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
shape: Ptr<'ctx, IntModel<SizeT>>,
|
||||
) -> Self {
|
||||
let fill_value = ndarray_one_value(generator, ctx, dtype);
|
||||
let fill_value = ScalarObject { value: fill_value, dtype };
|
||||
let fill_value = AnyObject { value: fill_value, ty: dtype };
|
||||
NDArrayObject::from_np_full(generator, ctx, dtype, ndims, shape, fill_value)
|
||||
}
|
||||
|
||||
|
@ -1,443 +1,13 @@
|
||||
use inkwell::{
|
||||
values::{BasicValue, FloatValue, IntValue},
|
||||
FloatPredicate, IntPredicate,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use inkwell::{FloatPredicate, IntPredicate};
|
||||
|
||||
use crate::{
|
||||
codegen::{llvm_intrinsics, model::*, stmt::gen_if_callback, CodeGenContext, CodeGenerator},
|
||||
typecheck::typedef::Type,
|
||||
use crate::codegen::{
|
||||
model::*,
|
||||
object::{AnyObject, MinOrMax},
|
||||
stmt::gen_if_callback,
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
|
||||
use super::{scalar::ScalarObject, NDArrayObject};
|
||||
|
||||
/// Convenience function to crash the program when types of arguments are not supported.
|
||||
/// Used to be debugged with a stacktrace.
|
||||
fn unsupported_type<I>(ctx: &CodeGenContext<'_, '_>, tys: I) -> !
|
||||
where
|
||||
I: IntoIterator<Item = Type>,
|
||||
{
|
||||
unreachable!(
|
||||
"unsupported types found '{}'",
|
||||
tys.into_iter().map(|ty| format!("'{}'", ctx.unifier.stringify(ty))).join(", "),
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum FloorOrCeil {
|
||||
Floor,
|
||||
Ceil,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum MinOrMax {
|
||||
Min,
|
||||
Max,
|
||||
}
|
||||
|
||||
fn signed_ints(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
|
||||
vec![ctx.primitives.int32, ctx.primitives.int64]
|
||||
}
|
||||
|
||||
fn unsigned_ints(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
|
||||
vec![ctx.primitives.uint32, ctx.primitives.uint64]
|
||||
}
|
||||
|
||||
fn ints(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
|
||||
vec![ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.uint32, ctx.primitives.uint64]
|
||||
}
|
||||
|
||||
fn int_like(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
|
||||
vec![
|
||||
ctx.primitives.bool,
|
||||
ctx.primitives.int32,
|
||||
ctx.primitives.int64,
|
||||
ctx.primitives.uint32,
|
||||
ctx.primitives.uint64,
|
||||
]
|
||||
}
|
||||
|
||||
fn cast_to_int_conversion<'ctx, 'a, G, HandleFloatFn>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
scalar: ScalarObject<'ctx>,
|
||||
ret_int_dtype: Type,
|
||||
handle_float: HandleFloatFn,
|
||||
) -> ScalarObject<'ctx>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
HandleFloatFn:
|
||||
FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, FloatValue<'ctx>) -> IntValue<'ctx>,
|
||||
{
|
||||
let ret_int_dtype_llvm = ctx.get_llvm_type(generator, ret_int_dtype).into_int_type();
|
||||
|
||||
let result = if ctx.unifier.unioned(scalar.dtype, ctx.primitives.float) {
|
||||
// Special handling for floats
|
||||
let n = scalar.value.into_float_value();
|
||||
handle_float(generator, ctx, n)
|
||||
} else if ctx.unifier.unioned_any(scalar.dtype, int_like(ctx)) {
|
||||
let n = scalar.value.into_int_value();
|
||||
|
||||
if n.get_type().get_bit_width() <= ret_int_dtype_llvm.get_bit_width() {
|
||||
ctx.builder.build_int_z_extend(n, ret_int_dtype_llvm, "zext").unwrap()
|
||||
} else {
|
||||
ctx.builder.build_int_truncate(n, ret_int_dtype_llvm, "trunc").unwrap()
|
||||
}
|
||||
} else {
|
||||
unsupported_type(ctx, [scalar.dtype]);
|
||||
};
|
||||
|
||||
assert_eq!(ret_int_dtype_llvm.get_bit_width(), result.get_type().get_bit_width()); // Sanity check
|
||||
ScalarObject { value: result.into(), dtype: ret_int_dtype }
|
||||
}
|
||||
|
||||
impl<'ctx> ScalarObject<'ctx> {
|
||||
/// Convenience function. Assume this scalar has typechecker type float64, get its underlying LLVM value.
|
||||
///
|
||||
/// Panic if the type is wrong.
|
||||
pub fn into_float64(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> FloatValue<'ctx> {
|
||||
if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
|
||||
self.value.into_float_value() // self.value must be a FloatValue
|
||||
} else {
|
||||
panic!("not a float type")
|
||||
}
|
||||
}
|
||||
|
||||
/// Convenience function. Assume this scalar has typechecker type int32, get its underlying LLVM value.
|
||||
///
|
||||
/// Panic if the type is wrong.
|
||||
pub fn into_int32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
|
||||
if ctx.unifier.unioned(self.dtype, ctx.primitives.int32) {
|
||||
let value = self.value.into_int_value();
|
||||
debug_assert_eq!(value.get_type().get_bit_width(), 32); // Sanity check
|
||||
value
|
||||
} else {
|
||||
panic!("not a float type")
|
||||
}
|
||||
}
|
||||
|
||||
/// Compare two scalars. Only int-to-int and float-to-float comparisons are allowed.
|
||||
/// Panic otherwise.
|
||||
pub fn compare<G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
lhs: ScalarObject<'ctx>,
|
||||
rhs: ScalarObject<'ctx>,
|
||||
int_predicate: IntPredicate,
|
||||
float_predicate: FloatPredicate,
|
||||
name: &str,
|
||||
) -> Int<'ctx, Bool> {
|
||||
if !ctx.unifier.unioned(lhs.dtype, rhs.dtype) {
|
||||
unsupported_type(ctx, [lhs.dtype, rhs.dtype]);
|
||||
}
|
||||
|
||||
let bool_model = IntModel(Bool);
|
||||
|
||||
let common_ty = lhs.dtype;
|
||||
let result = if ctx.unifier.unioned(common_ty, ctx.primitives.float) {
|
||||
let lhs = lhs.value.into_float_value();
|
||||
let rhs = rhs.value.into_float_value();
|
||||
ctx.builder.build_float_compare(float_predicate, lhs, rhs, name).unwrap()
|
||||
} else if ctx.unifier.unioned_any(common_ty, int_like(ctx)) {
|
||||
let lhs = lhs.value.into_int_value();
|
||||
let rhs = rhs.value.into_int_value();
|
||||
ctx.builder.build_int_compare(int_predicate, lhs, rhs, name).unwrap()
|
||||
} else {
|
||||
unsupported_type(ctx, [lhs.dtype, rhs.dtype]);
|
||||
};
|
||||
|
||||
bool_model.check_value(generator, ctx.ctx, result).unwrap()
|
||||
}
|
||||
|
||||
/// Invoke NAC3's builtin `int32()`.
|
||||
#[must_use]
|
||||
pub fn cast_to_int32<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> Self {
|
||||
cast_to_int_conversion(
|
||||
generator,
|
||||
ctx,
|
||||
*self,
|
||||
ctx.primitives.int32,
|
||||
|_generator, ctx, input| {
|
||||
let n =
|
||||
ctx.builder.build_float_to_signed_int(input, ctx.ctx.i64_type(), "").unwrap();
|
||||
ctx.builder.build_int_truncate(n, ctx.ctx.i32_type(), "conv").unwrap()
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
/// Invoke NAC3's builtin `int64()`.
|
||||
#[must_use]
|
||||
pub fn cast_to_int64<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> Self {
|
||||
cast_to_int_conversion(
|
||||
generator,
|
||||
ctx,
|
||||
*self,
|
||||
ctx.primitives.int64,
|
||||
|_generator, ctx, input| {
|
||||
ctx.builder.build_float_to_signed_int(input, ctx.ctx.i64_type(), "").unwrap()
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
/// Invoke NAC3's builtin `uint32()`.
|
||||
#[must_use]
|
||||
pub fn cast_to_uint32<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> Self {
|
||||
cast_to_int_conversion(
|
||||
generator,
|
||||
ctx,
|
||||
*self,
|
||||
ctx.primitives.uint32,
|
||||
|_generator, ctx, n| {
|
||||
let n_gez = ctx
|
||||
.builder
|
||||
.build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "")
|
||||
.unwrap();
|
||||
|
||||
let to_int32 =
|
||||
ctx.builder.build_float_to_signed_int(n, ctx.ctx.i32_type(), "").unwrap();
|
||||
let to_uint64 =
|
||||
ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_select(
|
||||
n_gez,
|
||||
ctx.builder.build_int_truncate(to_uint64, ctx.ctx.i32_type(), "").unwrap(),
|
||||
to_int32,
|
||||
"conv",
|
||||
)
|
||||
.unwrap()
|
||||
.into_int_value()
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
/// Invoke NAC3's builtin `uint64()`.
|
||||
#[must_use]
|
||||
pub fn cast_to_uint64<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> Self {
|
||||
cast_to_int_conversion(
|
||||
generator,
|
||||
ctx,
|
||||
*self,
|
||||
ctx.primitives.uint64,
|
||||
|_generator, ctx, n| {
|
||||
let val_gez = ctx
|
||||
.builder
|
||||
.build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "")
|
||||
.unwrap();
|
||||
|
||||
let to_int64 =
|
||||
ctx.builder.build_float_to_signed_int(n, ctx.ctx.i64_type(), "").unwrap();
|
||||
let to_uint64 =
|
||||
ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap();
|
||||
|
||||
ctx.builder
|
||||
.build_select(val_gez, to_uint64, to_int64, "conv")
|
||||
.unwrap()
|
||||
.into_int_value()
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
/// Invoke NAC3's builtin `bool()`.
|
||||
#[must_use]
|
||||
pub fn cast_to_bool(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
|
||||
// TODO: Why is the original code being so lax about i1 and i8 for the returned int type?
|
||||
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.bool) {
|
||||
self.value.into_int_value()
|
||||
} else if ctx.unifier.unioned_any(self.dtype, ints(ctx)) {
|
||||
let n = self.value.into_int_value();
|
||||
ctx.builder
|
||||
.build_int_compare(inkwell::IntPredicate::NE, n, n.get_type().const_zero(), "bool")
|
||||
.unwrap()
|
||||
} else if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
|
||||
let n = self.value.into_float_value();
|
||||
ctx.builder
|
||||
.build_float_compare(FloatPredicate::UNE, n, n.get_type().const_zero(), "bool")
|
||||
.unwrap()
|
||||
} else {
|
||||
unsupported_type(ctx, [self.dtype])
|
||||
};
|
||||
|
||||
ScalarObject { dtype: ctx.primitives.bool, value: result.as_basic_value_enum() }
|
||||
}
|
||||
|
||||
/// Invoke NAC3's builtin `float()`.
|
||||
#[must_use]
|
||||
pub fn cast_to_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
|
||||
let result: FloatValue<'_> = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
|
||||
self.value.into_float_value()
|
||||
} else if ctx
|
||||
.unifier
|
||||
.unioned_any(self.dtype, [signed_ints(ctx).as_slice(), &[ctx.primitives.bool]].concat())
|
||||
{
|
||||
let n = self.value.into_int_value();
|
||||
ctx.builder.build_signed_int_to_float(n, llvm_f64, "sitofp").unwrap()
|
||||
} else if ctx.unifier.unioned_any(self.dtype, unsigned_ints(ctx)) {
|
||||
let n = self.value.into_int_value();
|
||||
ctx.builder.build_unsigned_int_to_float(n, llvm_f64, "uitofp").unwrap()
|
||||
} else {
|
||||
unsupported_type(ctx, [self.dtype]);
|
||||
};
|
||||
|
||||
ScalarObject { value: result.as_basic_value_enum(), dtype: ctx.primitives.float }
|
||||
}
|
||||
|
||||
/// Invoke NAC3's builtin `round()`.
|
||||
#[must_use]
|
||||
pub fn round<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ret_int_dtype: Type,
|
||||
) -> Self {
|
||||
let ret_int_dtype_llvm = ctx.get_llvm_type(generator, ret_int_dtype).into_int_type();
|
||||
|
||||
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
|
||||
let n = self.value.into_float_value();
|
||||
let n = llvm_intrinsics::call_float_round(ctx, n, None);
|
||||
ctx.builder.build_float_to_signed_int(n, ret_int_dtype_llvm, "round").unwrap()
|
||||
} else {
|
||||
unsupported_type(ctx, [self.dtype, ret_int_dtype])
|
||||
};
|
||||
ScalarObject { dtype: ret_int_dtype, value: result.as_basic_value_enum() }
|
||||
}
|
||||
|
||||
/// Invoke NAC3's builtin `np_round()`.
|
||||
///
|
||||
/// NOTE: `np.round()` has different behaviors than `round()` when in comes to "tie" cases and return type.
|
||||
#[must_use]
|
||||
pub fn np_round(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
|
||||
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
|
||||
let n = self.value.into_float_value();
|
||||
llvm_intrinsics::call_float_rint(ctx, n, None)
|
||||
} else {
|
||||
unsupported_type(ctx, [self.dtype])
|
||||
};
|
||||
ScalarObject { dtype: ctx.primitives.float, value: result.as_basic_value_enum() }
|
||||
}
|
||||
|
||||
/// Invoke NAC3's builtin `min()` or `max()`.
|
||||
pub fn min_or_max(
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
kind: MinOrMax,
|
||||
a: Self,
|
||||
b: Self,
|
||||
) -> Self {
|
||||
if !ctx.unifier.unioned(a.dtype, b.dtype) {
|
||||
unsupported_type(ctx, [a.dtype, b.dtype])
|
||||
}
|
||||
|
||||
let common_dtype = a.dtype;
|
||||
|
||||
if ctx.unifier.unioned(common_dtype, ctx.primitives.float) {
|
||||
let function = match kind {
|
||||
MinOrMax::Min => llvm_intrinsics::call_float_minnum,
|
||||
MinOrMax::Max => llvm_intrinsics::call_float_maxnum,
|
||||
};
|
||||
let result =
|
||||
function(ctx, a.value.into_float_value(), b.value.into_float_value(), None);
|
||||
ScalarObject { value: result.as_basic_value_enum(), dtype: ctx.primitives.float }
|
||||
} else if ctx.unifier.unioned_any(
|
||||
common_dtype,
|
||||
[unsigned_ints(ctx).as_slice(), &[ctx.primitives.bool]].concat(),
|
||||
) {
|
||||
// Treating bool has an unsigned int since that is convenient
|
||||
let function = match kind {
|
||||
MinOrMax::Min => llvm_intrinsics::call_int_umin,
|
||||
MinOrMax::Max => llvm_intrinsics::call_int_umax,
|
||||
};
|
||||
let result = function(ctx, a.value.into_int_value(), b.value.into_int_value(), None);
|
||||
ScalarObject { value: result.as_basic_value_enum(), dtype: common_dtype }
|
||||
} else {
|
||||
unsupported_type(ctx, [common_dtype])
|
||||
}
|
||||
}
|
||||
|
||||
/// Invoke NAC3's builtin `floor()` or `ceil()`.
|
||||
///
|
||||
/// * `ret_int_dtype` - The type of int to return.
|
||||
///
|
||||
/// Takes in a float/int and returns an int of type `ret_int_dtype`
|
||||
#[must_use]
|
||||
pub fn floor_or_ceil<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
kind: FloorOrCeil,
|
||||
ret_int_dtype: Type,
|
||||
) -> Self {
|
||||
let ret_int_dtype_llvm = ctx.get_llvm_type(generator, ret_int_dtype).into_int_type();
|
||||
|
||||
if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
|
||||
let function = match kind {
|
||||
FloorOrCeil::Floor => llvm_intrinsics::call_float_floor,
|
||||
FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil,
|
||||
};
|
||||
let n = self.value.into_float_value();
|
||||
let n = function(ctx, n, None);
|
||||
|
||||
let n = ctx.builder.build_float_to_signed_int(n, ret_int_dtype_llvm, "").unwrap();
|
||||
ScalarObject { dtype: ret_int_dtype, value: n.as_basic_value_enum() }
|
||||
} else {
|
||||
unsupported_type(ctx, [self.dtype])
|
||||
}
|
||||
}
|
||||
|
||||
/// Invoke NAC3's builtin `np_floor()`/ `np_ceil()`.
|
||||
///
|
||||
/// Takes in a float/int and returns a float64 result.
|
||||
#[must_use]
|
||||
pub fn np_floor_or_ceil(&self, ctx: &mut CodeGenContext<'ctx, '_>, kind: FloorOrCeil) -> Self {
|
||||
if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
|
||||
let function = match kind {
|
||||
FloorOrCeil::Floor => llvm_intrinsics::call_float_floor,
|
||||
FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil,
|
||||
};
|
||||
let n = self.value.into_float_value();
|
||||
let n = function(ctx, n, None);
|
||||
ScalarObject { dtype: ctx.primitives.float, value: n.as_basic_value_enum() }
|
||||
} else {
|
||||
unsupported_type(ctx, [self.dtype])
|
||||
}
|
||||
}
|
||||
|
||||
/// Invoke NAC3's builtin `abs()`.
|
||||
#[must_use]
|
||||
pub fn abs(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
|
||||
if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
|
||||
let n = self.value.into_float_value();
|
||||
let n = llvm_intrinsics::call_float_fabs(ctx, n, Some("abs"));
|
||||
ScalarObject { value: n.into(), dtype: ctx.primitives.float }
|
||||
} else if ctx.unifier.unioned_any(self.dtype, ints(ctx)) {
|
||||
let n = self.value.into_int_value();
|
||||
|
||||
let is_poisoned = ctx.ctx.bool_type().const_zero(); // is_poisoned = false
|
||||
let n = llvm_intrinsics::call_int_abs(ctx, n, is_poisoned, Some("abs"));
|
||||
|
||||
ScalarObject { value: n.into(), dtype: self.dtype }
|
||||
} else {
|
||||
unsupported_type(ctx, [self.dtype])
|
||||
}
|
||||
}
|
||||
}
|
||||
use super::NDArrayObject;
|
||||
|
||||
impl<'ctx> NDArrayObject<'ctx> {
|
||||
/// Helper function to implement NAC3's builtin `np_min()`, `np_max()`, `np_argmin()`, and `np_argmax()`.
|
||||
@ -451,7 +21,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
kind: MinOrMax,
|
||||
on_empty_err_msg: &str,
|
||||
) -> (ScalarObject<'ctx>, Int<'ctx, SizeT>) {
|
||||
) -> (AnyObject<'ctx>, Int<'ctx, SizeT>) {
|
||||
let sizet_model = IntModel(SizeT);
|
||||
let dtype_llvm = ctx.get_llvm_type(generator, self.dtype);
|
||||
|
||||
@ -478,17 +48,17 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
|
||||
self.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
|
||||
let old_extremum = ctx.builder.build_load(pextremum, "current_extremum").unwrap();
|
||||
let old_extremum = ScalarObject { dtype: self.dtype, value: old_extremum };
|
||||
let old_extremum = AnyObject { ty: self.dtype, value: old_extremum };
|
||||
|
||||
let scalar = nditer.get_scalar(generator, ctx);
|
||||
let new_extremum = ScalarObject::min_or_max(ctx, kind, old_extremum, scalar);
|
||||
let new_extremum = AnyObject::call_min_or_max(ctx, kind, old_extremum, scalar);
|
||||
|
||||
gen_if_callback(
|
||||
generator,
|
||||
ctx,
|
||||
|generator, ctx| {
|
||||
// Is new_extremum is more extreme than old_extremum?
|
||||
let cmp = ScalarObject::compare(
|
||||
let cmp = AnyObject::compare_int_or_float_by_predicate(
|
||||
generator,
|
||||
ctx,
|
||||
new_extremum,
|
||||
@ -517,7 +87,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
let extremum_index = pextremum_index.load(generator, ctx, "extremum_index");
|
||||
|
||||
let extremum = ctx.builder.build_load(pextremum, "extremum_value").unwrap();
|
||||
let extremum = ScalarObject { dtype: self.dtype, value: extremum };
|
||||
let extremum = AnyObject { ty: self.dtype, value: extremum };
|
||||
|
||||
(extremum, extremum_index)
|
||||
}
|
||||
@ -528,7 +98,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
kind: MinOrMax,
|
||||
) -> ScalarObject<'ctx> {
|
||||
) -> AnyObject<'ctx> {
|
||||
let on_empty_err_msg = format!(
|
||||
"zero-size array to reduction operation {} which has no identity",
|
||||
match kind {
|
||||
|
@ -1,9 +1,8 @@
|
||||
use inkwell::values::BasicValueEnum;
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::{
|
||||
codegen::{
|
||||
object::ndarray::{NDArrayObject, ScalarObject},
|
||||
object::ndarray::{AnyObject, NDArrayObject},
|
||||
stmt::gen_for_callback,
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
@ -27,8 +26,8 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
MappingFn: FnOnce(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
&[ScalarObject<'ctx>],
|
||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
&[AnyObject<'ctx>],
|
||||
) -> Result<AnyObject<'ctx>, String>,
|
||||
{
|
||||
// Broadcast inputs
|
||||
let broadcast_result = NDArrayObject::broadcast(generator, ctx, ndarrays);
|
||||
@ -89,9 +88,11 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
in_nditers.iter().map(|nditer| nditer.get_scalar(generator, ctx)).collect_vec();
|
||||
|
||||
let result = mapping(generator, ctx, &in_scalars)?;
|
||||
// Sanity check on result's ty
|
||||
assert!(ctx.unifier.unioned(result.ty, out_ndarray.dtype));
|
||||
|
||||
let p = out_nditer.get_pointer(generator, ctx);
|
||||
ctx.builder.build_store(p, result).unwrap();
|
||||
ctx.builder.build_store(p, result.value).unwrap();
|
||||
|
||||
Ok(())
|
||||
},
|
||||
@ -103,20 +104,6 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
},
|
||||
)?;
|
||||
|
||||
// let start = sizet_model.const_0(generator, ctx.ctx);
|
||||
// let stop = broadcast_result.ndarrays[0].size(generator, ctx); // They all should have the same `np.size`.
|
||||
// let step = sizet_model.const_1(generator, ctx.ctx);
|
||||
// gen_for_model_auto(generator, ctx, start, stop, step, move |generator, ctx, _hooks, i| {
|
||||
// let elements =
|
||||
// ndarrays.iter().map(|ndarray| ndarray.get_nth_scalar(generator, ctx, i)).collect_vec();
|
||||
|
||||
// let ret = mapping(generator, ctx, i, &elements)?;
|
||||
|
||||
// let pret = out_ndarray.get_nth_pointer(generator, ctx, i, "pret");
|
||||
// ctx.builder.build_store(pret, ret).unwrap();
|
||||
// Ok(())
|
||||
// })?;
|
||||
|
||||
Ok(out_ndarray)
|
||||
}
|
||||
|
||||
@ -132,8 +119,8 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
Mapping: FnOnce(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
ScalarObject<'ctx>,
|
||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
AnyObject<'ctx>,
|
||||
) -> Result<AnyObject<'ctx>, String>,
|
||||
{
|
||||
NDArrayObject::broadcasting_starmap(
|
||||
generator,
|
||||
@ -160,16 +147,18 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||
MappingFn: FnOnce(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
&[ScalarObject<'ctx>],
|
||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
&[AnyObject<'ctx>],
|
||||
) -> Result<AnyObject<'ctx>, String>,
|
||||
{
|
||||
// Check if all inputs are ScalarObjects
|
||||
let all_scalars: Option<Vec<_>> =
|
||||
inputs.iter().map(ScalarObject::try_from).try_collect().ok();
|
||||
// Check if all inputs are AnyObjects
|
||||
let all_scalars: Option<Vec<_>> = inputs.iter().map(AnyObject::try_from).try_collect().ok();
|
||||
|
||||
if let Some(scalars) = all_scalars {
|
||||
let scalar =
|
||||
ScalarObject { value: mapping(generator, ctx, &scalars)?, dtype: ret_dtype };
|
||||
let scalar = mapping(generator, ctx, &scalars)?;
|
||||
|
||||
// Sanity check on scalar's type
|
||||
assert!(ctx.unifier.unioned(scalar.ty, ret_dtype));
|
||||
|
||||
Ok(ScalarOrNDArray::Scalar(scalar))
|
||||
} else {
|
||||
// Promote all input to ndarrays and map through them.
|
||||
@ -197,8 +186,8 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||
Mapping: FnOnce(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
ScalarObject<'ctx>,
|
||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
AnyObject<'ctx>,
|
||||
) -> Result<AnyObject<'ctx>, String>,
|
||||
{
|
||||
ScalarOrNDArray::broadcasting_starmap(
|
||||
generator,
|
||||
|
@ -38,7 +38,7 @@ use inkwell::{
|
||||
AddressSpace, IntPredicate,
|
||||
};
|
||||
use nditer::NDIterHandle;
|
||||
use scalar::{ScalarObject, ScalarOrNDArray};
|
||||
use scalar::ScalarOrNDArray;
|
||||
use util::call_memcpy_model;
|
||||
|
||||
use super::{tuple::TupleObject, AnyObject};
|
||||
@ -257,10 +257,10 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
nth: Int<'ctx, SizeT>,
|
||||
) -> ScalarObject<'ctx> {
|
||||
) -> AnyObject<'ctx> {
|
||||
let p = self.get_nth_pointer(generator, ctx, nth, "value");
|
||||
let value = ctx.builder.build_load(p, "value").unwrap();
|
||||
ScalarObject { dtype: self.dtype, value }
|
||||
AnyObject { ty: self.dtype, value }
|
||||
}
|
||||
|
||||
/// Set the n-th (0-based) scalar.
|
||||
@ -271,10 +271,10 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
nth: Int<'ctx, SizeT>,
|
||||
scalar: ScalarObject<'ctx>,
|
||||
scalar: AnyObject<'ctx>,
|
||||
) {
|
||||
// Sanity check on scalar's `dtype`
|
||||
assert!(ctx.unifier.unioned(scalar.dtype, self.dtype));
|
||||
assert!(ctx.unifier.unioned(scalar.ty, self.dtype));
|
||||
|
||||
let pscalar = self.get_nth_pointer(generator, ctx, nth, "pscalar");
|
||||
ctx.builder.build_store(pscalar, scalar.value).unwrap();
|
||||
@ -284,7 +284,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
///
|
||||
/// Please refer to the IRRT implementation to see its purpose.
|
||||
pub fn update_strides_by_shape<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) {
|
||||
@ -458,7 +458,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
self.ndims == 0
|
||||
}
|
||||
|
||||
/// If this ndarray is unsized, return its sole value as a [`ScalarObject`]. Otherwise, do nothing.
|
||||
/// If this ndarray is unsized, return its sole value as a [`AnyObject`]. Otherwise, do nothing.
|
||||
pub fn split_unsized<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
@ -604,10 +604,10 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
scalar: ScalarObject<'ctx>,
|
||||
scalar: AnyObject<'ctx>,
|
||||
) {
|
||||
// Sanity check on scalar's type.
|
||||
assert!(ctx.unifier.unioned(self.dtype, scalar.dtype));
|
||||
assert!(ctx.unifier.unioned(self.dtype, scalar.ty));
|
||||
|
||||
self.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
|
||||
let p = nditer.get_pointer(generator, ctx);
|
||||
|
@ -3,7 +3,7 @@ use inkwell::{types::BasicType, values::PointerValue, AddressSpace};
|
||||
use crate::codegen::{
|
||||
irrt::{call_nac3_nditer_has_next, call_nac3_nditer_initialize, call_nac3_nditer_next},
|
||||
model::*,
|
||||
object::ndarray::scalar::ScalarObject,
|
||||
object::AnyObject,
|
||||
structure::NDIter,
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
@ -63,10 +63,10 @@ impl<'ctx> NDIterHandle<'ctx> {
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> ScalarObject<'ctx> {
|
||||
) -> AnyObject<'ctx> {
|
||||
let p = self.get_pointer(generator, ctx);
|
||||
let value = ctx.builder.build_load(p, "value").unwrap();
|
||||
ScalarObject { dtype: self.ndarray.dtype, value }
|
||||
AnyObject { ty: self.ndarray.dtype, value }
|
||||
}
|
||||
|
||||
pub fn get_index<G: CodeGenerator + ?Sized>(
|
||||
|
@ -7,18 +7,7 @@ use crate::{
|
||||
|
||||
use super::NDArrayObject;
|
||||
|
||||
/// An LLVM numpy scalar with its [`Type`].
|
||||
///
|
||||
/// Intended to be used with [`ScalarOrNDArray`].
|
||||
///
|
||||
/// A scalar does not have to be an actual number. It could be arbitrary objects.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct ScalarObject<'ctx> {
|
||||
pub dtype: Type,
|
||||
pub value: BasicValueEnum<'ctx>,
|
||||
}
|
||||
|
||||
impl<'ctx> ScalarObject<'ctx> {
|
||||
impl<'ctx> AnyObject<'ctx> {
|
||||
/// Promote this scalar to an unsized ndarray (like doing `np.asarray`).
|
||||
///
|
||||
/// The scalar value is allocated onto the stack, and the ndarray's `data` will point to that
|
||||
@ -35,7 +24,7 @@ impl<'ctx> ScalarObject<'ctx> {
|
||||
ctx.builder.build_store(data, self.value).unwrap();
|
||||
let data = pbyte_model.pointer_cast(generator, ctx, data, "data");
|
||||
|
||||
let ndarray = NDArrayObject::alloca(generator, ctx, self.dtype, 0, "scalar_ndarray");
|
||||
let ndarray = NDArrayObject::alloca(generator, ctx, self.ty, 0, "scalar_ndarray");
|
||||
ndarray.instance.set(ctx, |f| f.data, data);
|
||||
ndarray
|
||||
}
|
||||
@ -44,7 +33,7 @@ impl<'ctx> ScalarObject<'ctx> {
|
||||
/// A convenience enum for implementing scalar/ndarray agnostic utilities.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum ScalarOrNDArray<'ctx> {
|
||||
Scalar(ScalarObject<'ctx>),
|
||||
Scalar(AnyObject<'ctx>),
|
||||
NDArray(NDArrayObject<'ctx>),
|
||||
}
|
||||
|
||||
@ -59,7 +48,7 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn into_scalar(&self) -> ScalarObject<'ctx> {
|
||||
pub fn into_scalar(&self) -> AnyObject<'ctx> {
|
||||
match self {
|
||||
ScalarOrNDArray::NDArray(_ndarray) => panic!("Got NDArray"),
|
||||
ScalarOrNDArray::Scalar(scalar) => *scalar,
|
||||
@ -91,13 +80,13 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||
#[must_use]
|
||||
pub fn dtype(&self) -> Type {
|
||||
match self {
|
||||
ScalarOrNDArray::Scalar(scalar) => scalar.dtype,
|
||||
ScalarOrNDArray::Scalar(scalar) => scalar.ty,
|
||||
ScalarOrNDArray::NDArray(ndarray) => ndarray.dtype,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for ScalarObject<'ctx> {
|
||||
impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for AnyObject<'ctx> {
|
||||
type Error = ();
|
||||
|
||||
fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result<Self, Self::Error> {
|
||||
@ -135,7 +124,7 @@ pub fn split_scalar_or_ndarray<'ctx, G: CodeGenerator + ?Sized>(
|
||||
ScalarOrNDArray::NDArray(ndarray)
|
||||
}
|
||||
_ => {
|
||||
let scalar = ScalarObject { dtype: object.ty, value: object.value };
|
||||
let scalar = AnyObject { ty: object.ty, value: object.value };
|
||||
ScalarOrNDArray::Scalar(scalar)
|
||||
}
|
||||
}
|
||||
|
@ -1143,12 +1143,12 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
ret_dtype,
|
||||
|generator, ctx, scalar| {
|
||||
let result = match prim {
|
||||
PrimDef::FunInt32 => scalar.cast_to_int32(generator, ctx),
|
||||
PrimDef::FunInt64 => scalar.cast_to_int64(generator, ctx),
|
||||
PrimDef::FunUInt32 => scalar.cast_to_uint32(generator, ctx),
|
||||
PrimDef::FunUInt64 => scalar.cast_to_uint64(generator, ctx),
|
||||
PrimDef::FunFloat => scalar.cast_to_float(ctx),
|
||||
PrimDef::FunBool => scalar.cast_to_bool(ctx),
|
||||
PrimDef::FunInt32 => scalar.call_int32(generator, ctx),
|
||||
PrimDef::FunInt64 => scalar.call_int64(generator, ctx),
|
||||
PrimDef::FunUInt32 => scalar.call_uint32(generator, ctx),
|
||||
PrimDef::FunUInt64 => scalar.call_uint64(generator, ctx),
|
||||
PrimDef::FunFloat => scalar.call_float(ctx),
|
||||
PrimDef::FunBool => scalar.call_bool(ctx),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
Ok(result.value)
|
||||
@ -1277,7 +1277,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
ctx,
|
||||
int_sized,
|
||||
|generator, ctx, scalar| {
|
||||
Ok(scalar.floor_or_ceil(generator, ctx, kind, int_sized).value)
|
||||
Ok(scalar.call_floor_or_ceil(generator, ctx, kind, int_sized).value)
|
||||
},
|
||||
)?;
|
||||
Ok(Some(result.to_basic_value_enum()))
|
||||
@ -1927,7 +1927,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||
let arg = AnyObject { value: arg, ty: arg_ty };
|
||||
|
||||
Ok(Some(arg.len(generator, ctx).value.as_basic_value_enum()))
|
||||
Ok(Some(arg.call_len(generator, ctx).value.as_basic_value_enum()))
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
|
Loading…
Reference in New Issue
Block a user