forked from M-Labs/nac3
WIP: core/ndstrides: remove ScalarObject
This commit is contained in:
parent
f8b934096d
commit
4b765cfb27
|
@ -1,18 +1,71 @@
|
||||||
use inkwell::values::BasicValueEnum;
|
use inkwell::{
|
||||||
|
values::{BasicValue, BasicValueEnum, FloatValue, IntValue},
|
||||||
|
FloatPredicate, IntPredicate,
|
||||||
|
};
|
||||||
|
use itertools::Itertools;
|
||||||
use list::ListObject;
|
use list::ListObject;
|
||||||
use ndarray::NDArrayObject;
|
use ndarray::{NDArrayObject, NDArrayOut};
|
||||||
use range::RangeObject;
|
use range::RangeObject;
|
||||||
use tuple::TupleObject;
|
use tuple::TupleObject;
|
||||||
|
|
||||||
use crate::typecheck::typedef::{Type, TypeEnum};
|
use crate::{
|
||||||
|
toplevel::helper::PrimDef,
|
||||||
|
typecheck::typedef::{Type, TypeEnum},
|
||||||
|
};
|
||||||
|
|
||||||
use super::{model::*, CodeGenContext, CodeGenerator};
|
use super::{llvm_intrinsics, model::*, CodeGenContext, CodeGenerator};
|
||||||
|
|
||||||
pub mod list;
|
pub mod list;
|
||||||
pub mod ndarray;
|
pub mod ndarray;
|
||||||
pub mod range;
|
pub mod range;
|
||||||
pub mod tuple;
|
pub mod tuple;
|
||||||
|
|
||||||
|
/// Convenience function to crash the program when types of arguments are not supported.
|
||||||
|
/// Used to be debugged with a stacktrace.
|
||||||
|
fn unsupported_type<I>(ctx: &CodeGenContext<'_, '_>, tys: I) -> !
|
||||||
|
where
|
||||||
|
I: IntoIterator<Item = Type>,
|
||||||
|
{
|
||||||
|
unreachable!(
|
||||||
|
"unsupported types found '{}'",
|
||||||
|
tys.into_iter().map(|ty| format!("'{}'", ctx.unifier.stringify(ty))).join(", "),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub enum FloorOrCeil {
|
||||||
|
Floor,
|
||||||
|
Ceil,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub enum MinOrMax {
|
||||||
|
Min,
|
||||||
|
Max,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn signed_ints(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
|
||||||
|
vec![ctx.primitives.int32, ctx.primitives.int64]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unsigned_ints(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
|
||||||
|
vec![ctx.primitives.uint32, ctx.primitives.uint64]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ints(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
|
||||||
|
vec![ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.uint32, ctx.primitives.uint64]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn int_like(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
|
||||||
|
vec![
|
||||||
|
ctx.primitives.bool,
|
||||||
|
ctx.primitives.int32,
|
||||||
|
ctx.primitives.int64,
|
||||||
|
ctx.primitives.uint32,
|
||||||
|
ctx.primitives.uint64,
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
pub struct AnyObject<'ctx> {
|
pub struct AnyObject<'ctx> {
|
||||||
pub ty: Type,
|
pub ty: Type,
|
||||||
|
@ -20,13 +73,260 @@ pub struct AnyObject<'ctx> {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ctx> AnyObject<'ctx> {
|
impl<'ctx> AnyObject<'ctx> {
|
||||||
// Get the `len()` of this object.
|
/// Returns true if this object's type is a [`TypeEnum::TObj`] and has the object ID as `prim`.
|
||||||
pub fn len<G: CodeGenerator + ?Sized>(
|
pub fn is_obj(&self, ctx: &mut CodeGenContext<'ctx, '_>, prim: PrimDef) -> bool {
|
||||||
|
match &*ctx.unifier.get_ty(self.ty) {
|
||||||
|
TypeEnum::TObj { obj_id, .. } => *obj_id == prim.id(),
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns true if this object's type is a [`TypeEnum::TTuple`]
|
||||||
|
pub fn is_tuple(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
|
||||||
|
matches!(&*ctx.unifier.get_ty(self.ty), TypeEnum::TTuple { .. })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_int32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
|
||||||
|
ctx.unifier.unioned(self.ty, ctx.primitives.int32)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_uint32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
|
||||||
|
ctx.unifier.unioned(self.ty, ctx.primitives.uint32)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_int64(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
|
||||||
|
ctx.unifier.unioned(self.ty, ctx.primitives.int64)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_uint64(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
|
||||||
|
ctx.unifier.unioned(self.ty, ctx.primitives.uint64)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_bool(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
|
||||||
|
ctx.unifier.unioned(self.ty, ctx.primitives.bool)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_int_like(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
|
||||||
|
ctx.unifier.unioned_any(self.ty, int_like(ctx))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_signed_int(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
|
||||||
|
ctx.unifier.unioned_any(self.ty, signed_ints(ctx))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_unsigned_int(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
|
||||||
|
ctx.unifier.unioned_any(self.ty, unsigned_ints(ctx))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convenience function. If object has type `int32`, `int64`, `uint32`, `uint64`, or `bool`,
|
||||||
|
/// get its underlying LLVM value.
|
||||||
|
///
|
||||||
|
/// Panic if the type is wrong.
|
||||||
|
pub fn into_int(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
|
||||||
|
if self.is_int_like(ctx) {
|
||||||
|
self.value.into_int_value()
|
||||||
|
} else {
|
||||||
|
panic!("not an int32 type")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
|
||||||
|
self.is_obj(ctx, PrimDef::Float)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convenience function. If object has type `float`, get its underlying LLVM value.
|
||||||
|
///
|
||||||
|
/// Panic if the type is wrong.
|
||||||
|
pub fn into_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Float<'ctx, Float64> {
|
||||||
|
if self.is_float(ctx) {
|
||||||
|
// self.value must be a FloatValue
|
||||||
|
FloatModel(Float64).believe_value(self.value.into_float_value())
|
||||||
|
} else {
|
||||||
|
panic!("not a float type")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_ndarray(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
|
||||||
|
self.is_obj(ctx, PrimDef::NDArray)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn into_ndarray<G: CodeGenerator + ?Sized>(
|
||||||
&self,
|
&self,
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
) -> Int<'ctx, Int32> {
|
) -> NDArrayObject<'ctx> {
|
||||||
match &*ctx.unifier.get_ty_immutable(self.ty) {
|
NDArrayObject::from_object(generator, ctx, *self)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper function to compare two scalars.
|
||||||
|
///
|
||||||
|
/// Only int-to-int and float-to-float comparisons are allowed.
|
||||||
|
///
|
||||||
|
/// Panic otherwise.
|
||||||
|
pub fn compare_int_or_float_by_predicate<G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
lhs: AnyObject<'ctx>,
|
||||||
|
rhs: AnyObject<'ctx>,
|
||||||
|
int_predicate: IntPredicate,
|
||||||
|
float_predicate: FloatPredicate,
|
||||||
|
name: &str,
|
||||||
|
) -> Int<'ctx, Bool> {
|
||||||
|
if !ctx.unifier.unioned(lhs.ty, rhs.ty) {
|
||||||
|
panic!("lhs and rhs type are not the same.")
|
||||||
|
}
|
||||||
|
|
||||||
|
let bool_model = IntModel(Bool);
|
||||||
|
|
||||||
|
let common_ty = lhs.ty;
|
||||||
|
let result = if lhs.is_float(ctx) {
|
||||||
|
let lhs = lhs.into_float(ctx);
|
||||||
|
let rhs = rhs.into_float(ctx);
|
||||||
|
ctx.builder.build_float_compare(float_predicate, lhs.value, rhs.value, name).unwrap()
|
||||||
|
} else if ctx.unifier.unioned_any(common_ty, int_like(ctx)) {
|
||||||
|
let lhs = lhs.into_int(ctx);
|
||||||
|
let rhs = rhs.into_int(ctx);
|
||||||
|
ctx.builder.build_int_compare(int_predicate, lhs, rhs, name).unwrap()
|
||||||
|
} else {
|
||||||
|
unsupported_type(ctx, [lhs.ty, rhs.ty]);
|
||||||
|
};
|
||||||
|
|
||||||
|
bool_model.check_value(generator, ctx.ctx, result).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper function for `int32()`, `int64()`, `uint32()`, and `uint64()`.
|
||||||
|
///
|
||||||
|
/// TODO: Document me
|
||||||
|
fn cast_to_int_conversion<'a, G, HandleFloatFn>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
ret_int_ty: Type,
|
||||||
|
handle_float: HandleFloatFn,
|
||||||
|
) -> AnyObject<'ctx>
|
||||||
|
where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
HandleFloatFn:
|
||||||
|
FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, FloatValue<'ctx>) -> IntValue<'ctx>,
|
||||||
|
{
|
||||||
|
let ret_int_ty_llvm = ctx.get_llvm_type(generator, ret_int_ty).into_int_type();
|
||||||
|
|
||||||
|
let result = if self.is_float(ctx) {
|
||||||
|
// Handle float to int
|
||||||
|
let n = self.into_float(ctx);
|
||||||
|
handle_float(generator, ctx, n.value)
|
||||||
|
} else if self.is_int_like(ctx) {
|
||||||
|
// Handle int to a new int type
|
||||||
|
let n = self.into_int(ctx);
|
||||||
|
if n.get_type().get_bit_width() <= ret_int_ty_llvm.get_bit_width() {
|
||||||
|
ctx.builder.build_int_z_extend(n, ret_int_ty_llvm, "zext").unwrap()
|
||||||
|
} else {
|
||||||
|
ctx.builder.build_int_truncate(n, ret_int_ty_llvm, "trunc").unwrap()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
unsupported_type(ctx, [self.ty]);
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(ret_int_ty_llvm.get_bit_width(), result.get_type().get_bit_width()); // Sanity check
|
||||||
|
AnyObject { value: result.into(), ty: ret_int_ty }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Call `int32()` on this object.
|
||||||
|
#[must_use]
|
||||||
|
pub fn call_int32<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) -> AnyObject<'ctx> {
|
||||||
|
self.cast_to_int_conversion(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
ctx.primitives.int32,
|
||||||
|
|_generator, ctx, input| {
|
||||||
|
let n =
|
||||||
|
ctx.builder.build_float_to_signed_int(input, ctx.ctx.i64_type(), "").unwrap();
|
||||||
|
ctx.builder.build_int_truncate(n, ctx.ctx.i32_type(), "conv").unwrap()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Call `int64()` on this object.
|
||||||
|
#[must_use]
|
||||||
|
pub fn call_int64<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) -> AnyObject<'ctx> {
|
||||||
|
self.cast_to_int_conversion(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
ctx.primitives.int64,
|
||||||
|
|_generator, ctx, input| {
|
||||||
|
ctx.builder.build_float_to_signed_int(input, ctx.ctx.i64_type(), "").unwrap()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Call `uint32()` on this object.
|
||||||
|
#[must_use]
|
||||||
|
pub fn call_uint32<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) -> AnyObject<'ctx> {
|
||||||
|
self.cast_to_int_conversion(generator, ctx, ctx.primitives.uint32, |_generator, ctx, n| {
|
||||||
|
let n_gez = ctx
|
||||||
|
.builder
|
||||||
|
.build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "")
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let to_int32 =
|
||||||
|
ctx.builder.build_float_to_signed_int(n, ctx.ctx.i32_type(), "").unwrap();
|
||||||
|
let to_uint64 =
|
||||||
|
ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap();
|
||||||
|
|
||||||
|
ctx.builder
|
||||||
|
.build_select(
|
||||||
|
n_gez,
|
||||||
|
ctx.builder.build_int_truncate(to_uint64, ctx.ctx.i32_type(), "").unwrap(),
|
||||||
|
to_int32,
|
||||||
|
"conv",
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
.into_int_value()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Call `uint64()` on this object.
|
||||||
|
#[must_use]
|
||||||
|
pub fn call_uint64<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) -> AnyObject<'ctx> {
|
||||||
|
self.cast_to_int_conversion(generator, ctx, ctx.primitives.uint64, |_generator, ctx, n| {
|
||||||
|
let val_gez = ctx
|
||||||
|
.builder
|
||||||
|
.build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "")
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let to_int64 =
|
||||||
|
ctx.builder.build_float_to_signed_int(n, ctx.ctx.i64_type(), "").unwrap();
|
||||||
|
let to_uint64 =
|
||||||
|
ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap();
|
||||||
|
|
||||||
|
ctx.builder.build_select(val_gez, to_uint64, to_int64, "conv").unwrap().into_int_value()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the `len()` of this object.
|
||||||
|
pub fn call_len<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) -> AnyObject<'ctx> {
|
||||||
|
// TODO: Switch to returning SizeT
|
||||||
|
let result = match &*ctx.unifier.get_ty_immutable(self.ty) {
|
||||||
TypeEnum::TTuple { .. } => {
|
TypeEnum::TTuple { .. } => {
|
||||||
let tuple = TupleObject::from_object(ctx, *self);
|
let tuple = TupleObject::from_object(ctx, *self);
|
||||||
tuple.len(generator, ctx).truncate(generator, ctx, Int32, "tuple_len_32")
|
tuple.len(generator, ctx).truncate(generator, ctx, Int32, "tuple_len_32")
|
||||||
|
@ -50,6 +350,259 @@ impl<'ctx> AnyObject<'ctx> {
|
||||||
ndarray.len(generator, ctx).truncate(generator, ctx, Int32, "ndarray_len_i32")
|
ndarray.len(generator, ctx).truncate(generator, ctx, Int32, "ndarray_len_i32")
|
||||||
}
|
}
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
AnyObject { ty: ctx.primitives.int32, value: result.value.as_basic_value_enum() }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Like [`AnyObject::call_bool`] but this returns an `Int<'ctx, Bool>` instead of an object.
|
||||||
|
pub fn bool(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Int<'ctx, Bool> {
|
||||||
|
let bool_model = IntModel(Bool);
|
||||||
|
if self.is_int_like(ctx) {
|
||||||
|
let n = self.into_int(ctx);
|
||||||
|
let n = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(inkwell::IntPredicate::NE, n, n.get_type().const_zero(), "bool")
|
||||||
|
.unwrap();
|
||||||
|
bool_model.believe_value(n)
|
||||||
|
} else if self.is_float(ctx) {
|
||||||
|
let n = self.value.into_float_value();
|
||||||
|
let n = ctx
|
||||||
|
.builder
|
||||||
|
.build_float_compare(FloatPredicate::UNE, n, n.get_type().const_zero(), "bool")
|
||||||
|
.unwrap();
|
||||||
|
bool_model.believe_value(n)
|
||||||
|
} else {
|
||||||
|
unsupported_type(ctx, [self.ty])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Call `bool()` on this object.
|
||||||
|
pub fn call_bool<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) -> AnyObject<'ctx> {
|
||||||
|
let n = self.bool(ctx);
|
||||||
|
|
||||||
|
// NAC3 booleans are i8
|
||||||
|
let llvm_i8 = ctx.ctx.i8_type();
|
||||||
|
let n = ctx.builder.build_int_z_extend(n.value, llvm_i8, "bool").unwrap();
|
||||||
|
|
||||||
|
AnyObject { ty: ctx.primitives.bool, value: n.as_basic_value_enum() }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Call `float()` on this object.
|
||||||
|
pub fn call_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> AnyObject<'ctx> {
|
||||||
|
let f64_model = FloatModel(Float64);
|
||||||
|
let llvm_f64 = ctx.ctx.f64_type();
|
||||||
|
|
||||||
|
let result = if self.is_float(ctx) {
|
||||||
|
self.into_float(ctx)
|
||||||
|
} else if self.is_signed_int(ctx) || self.is_bool(ctx) {
|
||||||
|
let n = self.into_int(ctx);
|
||||||
|
let n = ctx.builder.build_signed_int_to_float(n, llvm_f64, "sitofp").unwrap();
|
||||||
|
f64_model.believe_value(n)
|
||||||
|
} else if self.is_unsigned_int(ctx) {
|
||||||
|
let n = self.into_int(ctx);
|
||||||
|
let n = ctx.builder.build_unsigned_int_to_float(n, llvm_f64, "uitofp").unwrap();
|
||||||
|
f64_model.believe_value(n)
|
||||||
|
} else {
|
||||||
|
unsupported_type(ctx, [self.ty]);
|
||||||
|
};
|
||||||
|
|
||||||
|
AnyObject { ty: ctx.primitives.float, value: result.value.as_basic_value_enum() }
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call `abs()` on this object.
|
||||||
|
#[must_use]
|
||||||
|
pub fn call_abs<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) -> Self {
|
||||||
|
if self.is_float(ctx) {
|
||||||
|
let n = self.value.into_float_value();
|
||||||
|
let n = llvm_intrinsics::call_float_fabs(ctx, n, Some("abs"));
|
||||||
|
AnyObject { value: n.into(), ty: ctx.primitives.float }
|
||||||
|
} else if self.is_unsigned_int(ctx) || self.is_signed_int(ctx) {
|
||||||
|
let is_poisoned = ctx.ctx.bool_type().const_zero(); // is_poisoned = false
|
||||||
|
|
||||||
|
let n = self.value.into_int_value();
|
||||||
|
let n = llvm_intrinsics::call_int_abs(ctx, n, is_poisoned, Some("abs"));
|
||||||
|
AnyObject { value: n.into(), ty: self.ty }
|
||||||
|
} else if self.is_ndarray(ctx) {
|
||||||
|
let ndarray = self.into_ndarray(generator, ctx);
|
||||||
|
ndarray
|
||||||
|
.map(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
NDArrayOut::NewNDArray { dtype: ndarray.dtype },
|
||||||
|
|generator, ctx, scalar| Ok(scalar.call_abs(generator, ctx)),
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
.to_any_object(ctx)
|
||||||
|
} else {
|
||||||
|
unsupported_type(ctx, [self.ty])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call `round()` on this object.
|
||||||
|
//
|
||||||
|
// It is possible to specify which kind of int type to return.
|
||||||
|
pub fn call_round<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
ret_int_ty: Type,
|
||||||
|
) -> AnyObject<'ctx> {
|
||||||
|
let ret_int_ty_llvm = ctx.get_llvm_type(generator, ret_int_ty).into_int_type();
|
||||||
|
|
||||||
|
let result = if ctx.unifier.unioned(self.ty, ctx.primitives.float) {
|
||||||
|
let n = self.value.into_float_value();
|
||||||
|
let n = llvm_intrinsics::call_float_round(ctx, n, None);
|
||||||
|
ctx.builder.build_float_to_signed_int(n, ret_int_ty_llvm, "round").unwrap()
|
||||||
|
} else {
|
||||||
|
unsupported_type(ctx, [self.ty])
|
||||||
|
};
|
||||||
|
AnyObject { ty: ret_int_ty, value: result.as_basic_value_enum() }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Call `np_round()` on this object.
|
||||||
|
///
|
||||||
|
/// NOTE: `np.round()` has different behaviors than `round()` when in comes to "tie" cases and return type.
|
||||||
|
#[must_use]
|
||||||
|
pub fn call_np_round<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
) -> AnyObject<'ctx> {
|
||||||
|
if self.is_float(ctx) {
|
||||||
|
let n = self.into_float(ctx);
|
||||||
|
let n = llvm_intrinsics::call_float_rint(ctx, n.value, None);
|
||||||
|
AnyObject { ty: ctx.primitives.float, value: n.as_basic_value_enum() }
|
||||||
|
} else if self.is_ndarray(ctx) {
|
||||||
|
let ndarray = self.into_ndarray(generator, ctx);
|
||||||
|
ndarray
|
||||||
|
.map(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
NDArrayOut::NewNDArray { dtype: ndarray.dtype },
|
||||||
|
|generator, ctx, scalar| Ok(scalar.call_np_round(generator, ctx)),
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
.to_any_object(ctx)
|
||||||
|
} else {
|
||||||
|
unsupported_type(ctx, [self.ty])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Call `min()` or `max()` on two objects.
|
||||||
|
pub fn call_min_or_max(
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
kind: MinOrMax,
|
||||||
|
a: AnyObject<'ctx>,
|
||||||
|
b: AnyObject<'ctx>,
|
||||||
|
) -> AnyObject<'ctx> {
|
||||||
|
if !ctx.unifier.unioned(a.ty, b.ty) {
|
||||||
|
unsupported_type(ctx, [a.ty, b.ty])
|
||||||
|
}
|
||||||
|
|
||||||
|
let common_ty = a.ty;
|
||||||
|
|
||||||
|
if a.is_float(ctx) {
|
||||||
|
let function = match kind {
|
||||||
|
MinOrMax::Min => llvm_intrinsics::call_float_minnum,
|
||||||
|
MinOrMax::Max => llvm_intrinsics::call_float_maxnum,
|
||||||
|
};
|
||||||
|
|
||||||
|
let a = a.into_float(ctx).value;
|
||||||
|
let b = b.into_float(ctx).value;
|
||||||
|
let result = function(ctx, a, b, None);
|
||||||
|
AnyObject { value: result.as_basic_value_enum(), ty: ctx.primitives.float }
|
||||||
|
} else if a.is_unsigned_int(ctx) || a.is_bool(ctx) {
|
||||||
|
// Treating bool has an unsigned int since that is convenient
|
||||||
|
let function = match kind {
|
||||||
|
MinOrMax::Min => llvm_intrinsics::call_int_umin,
|
||||||
|
MinOrMax::Max => llvm_intrinsics::call_int_umax,
|
||||||
|
};
|
||||||
|
|
||||||
|
let a = a.into_int(ctx);
|
||||||
|
let b = b.into_int(ctx);
|
||||||
|
let result = function(ctx, a, b, None);
|
||||||
|
AnyObject { value: result.as_basic_value_enum(), ty: common_ty }
|
||||||
|
} else if a.is_signed_int(ctx) {
|
||||||
|
let function = match kind {
|
||||||
|
MinOrMax::Min => llvm_intrinsics::call_int_smin,
|
||||||
|
MinOrMax::Max => llvm_intrinsics::call_int_smax,
|
||||||
|
};
|
||||||
|
|
||||||
|
let a = a.into_int(ctx);
|
||||||
|
let b = b.into_int(ctx);
|
||||||
|
let result = function(ctx, a, b, None);
|
||||||
|
AnyObject { value: result.as_basic_value_enum(), ty: common_ty }
|
||||||
|
} else {
|
||||||
|
unsupported_type(ctx, [common_ty])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Call `floor()` or `ceil()` on this object.
|
||||||
|
///
|
||||||
|
/// It is possible to specify which kind of int type to return.
|
||||||
|
#[must_use]
|
||||||
|
pub fn call_floor_or_ceil<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
kind: FloorOrCeil,
|
||||||
|
ret_int_ty: Type,
|
||||||
|
) -> Self {
|
||||||
|
let ret_int_ty_llvm = ctx.get_llvm_type(generator, ret_int_ty).into_int_type();
|
||||||
|
|
||||||
|
if self.is_float(ctx) {
|
||||||
|
let function = match kind {
|
||||||
|
FloorOrCeil::Floor => llvm_intrinsics::call_float_floor,
|
||||||
|
FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil,
|
||||||
|
};
|
||||||
|
|
||||||
|
let n = self.into_float(ctx).value;
|
||||||
|
let n = function(ctx, n, None);
|
||||||
|
let n = ctx.builder.build_float_to_signed_int(n, ret_int_ty_llvm, "").unwrap();
|
||||||
|
AnyObject { ty: ret_int_ty, value: n.as_basic_value_enum() }
|
||||||
|
} else {
|
||||||
|
unsupported_type(ctx, [self.ty])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Call `np_floor()` or `np_ceil()` on this object.
|
||||||
|
#[must_use]
|
||||||
|
pub fn call_np_floor_or_ceil<G: CodeGenerator + ?Sized>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
kind: FloorOrCeil,
|
||||||
|
) -> Self {
|
||||||
|
// TODO:
|
||||||
|
if self.is_float(ctx) {
|
||||||
|
let function = match kind {
|
||||||
|
FloorOrCeil::Floor => llvm_intrinsics::call_float_floor,
|
||||||
|
FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil,
|
||||||
|
};
|
||||||
|
let n = self.into_float(ctx).value;
|
||||||
|
let n = function(ctx, n, None);
|
||||||
|
AnyObject { ty: ctx.primitives.float, value: n.as_basic_value_enum() }
|
||||||
|
} else if self.is_ndarray(ctx) {
|
||||||
|
let ndarray = self.into_ndarray(generator, ctx);
|
||||||
|
ndarray
|
||||||
|
.map(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
NDArrayOut::NewNDArray { dtype: ctx.primitives.float },
|
||||||
|
|generator, ctx, scalar| Ok(scalar.call_np_floor_or_ceil(generator, ctx, kind)),
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
.to_any_object(ctx)
|
||||||
|
} else {
|
||||||
|
unsupported_type(ctx, [self.ty])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
use inkwell::{values::BasicValueEnum, IntPredicate};
|
use inkwell::{values::BasicValueEnum, IntPredicate};
|
||||||
|
|
||||||
use super::{scalar::ScalarObject, NDArrayObject};
|
use super::NDArrayObject;
|
||||||
use crate::{
|
use crate::{
|
||||||
codegen::{
|
codegen::{
|
||||||
irrt::call_nac3_ndarray_util_assert_shape_no_negative, model::*, CodeGenContext,
|
irrt::call_nac3_ndarray_util_assert_shape_no_negative, model::*, object::AnyObject,
|
||||||
CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
},
|
},
|
||||||
typecheck::typedef::Type,
|
typecheck::typedef::Type,
|
||||||
};
|
};
|
||||||
|
@ -93,10 +93,10 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
dtype: Type,
|
dtype: Type,
|
||||||
ndims: u64,
|
ndims: u64,
|
||||||
shape: Ptr<'ctx, IntModel<SizeT>>,
|
shape: Ptr<'ctx, IntModel<SizeT>>,
|
||||||
fill_value: ScalarObject<'ctx>,
|
fill_value: AnyObject<'ctx>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
// Sanity check on `fill_value`'s dtype.
|
// Sanity check on `fill_value`'s dtype.
|
||||||
assert!(ctx.unifier.unioned(dtype, fill_value.dtype));
|
assert!(ctx.unifier.unioned(dtype, fill_value.ty));
|
||||||
|
|
||||||
let ndarray = NDArrayObject::from_np_empty(generator, ctx, dtype, ndims, shape);
|
let ndarray = NDArrayObject::from_np_empty(generator, ctx, dtype, ndims, shape);
|
||||||
ndarray.fill(generator, ctx, fill_value);
|
ndarray.fill(generator, ctx, fill_value);
|
||||||
|
@ -112,7 +112,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
shape: Ptr<'ctx, IntModel<SizeT>>,
|
shape: Ptr<'ctx, IntModel<SizeT>>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let fill_value = ndarray_zero_value(generator, ctx, dtype);
|
let fill_value = ndarray_zero_value(generator, ctx, dtype);
|
||||||
let fill_value = ScalarObject { value: fill_value, dtype };
|
let fill_value = AnyObject { value: fill_value, ty: dtype };
|
||||||
NDArrayObject::from_np_full(generator, ctx, dtype, ndims, shape, fill_value)
|
NDArrayObject::from_np_full(generator, ctx, dtype, ndims, shape, fill_value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -125,7 +125,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
shape: Ptr<'ctx, IntModel<SizeT>>,
|
shape: Ptr<'ctx, IntModel<SizeT>>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let fill_value = ndarray_one_value(generator, ctx, dtype);
|
let fill_value = ndarray_one_value(generator, ctx, dtype);
|
||||||
let fill_value = ScalarObject { value: fill_value, dtype };
|
let fill_value = AnyObject { value: fill_value, ty: dtype };
|
||||||
NDArrayObject::from_np_full(generator, ctx, dtype, ndims, shape, fill_value)
|
NDArrayObject::from_np_full(generator, ctx, dtype, ndims, shape, fill_value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,443 +1,13 @@
|
||||||
use inkwell::{
|
use inkwell::{FloatPredicate, IntPredicate};
|
||||||
values::{BasicValue, FloatValue, IntValue},
|
|
||||||
FloatPredicate, IntPredicate,
|
|
||||||
};
|
|
||||||
use itertools::Itertools;
|
|
||||||
|
|
||||||
use crate::{
|
use crate::codegen::{
|
||||||
codegen::{llvm_intrinsics, model::*, stmt::gen_if_callback, CodeGenContext, CodeGenerator},
|
model::*,
|
||||||
typecheck::typedef::Type,
|
object::{AnyObject, MinOrMax},
|
||||||
|
stmt::gen_if_callback,
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::{scalar::ScalarObject, NDArrayObject};
|
use super::NDArrayObject;
|
||||||
|
|
||||||
/// Convenience function to crash the program when types of arguments are not supported.
|
|
||||||
/// Used to be debugged with a stacktrace.
|
|
||||||
fn unsupported_type<I>(ctx: &CodeGenContext<'_, '_>, tys: I) -> !
|
|
||||||
where
|
|
||||||
I: IntoIterator<Item = Type>,
|
|
||||||
{
|
|
||||||
unreachable!(
|
|
||||||
"unsupported types found '{}'",
|
|
||||||
tys.into_iter().map(|ty| format!("'{}'", ctx.unifier.stringify(ty))).join(", "),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
|
||||||
pub enum FloorOrCeil {
|
|
||||||
Floor,
|
|
||||||
Ceil,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
|
||||||
pub enum MinOrMax {
|
|
||||||
Min,
|
|
||||||
Max,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn signed_ints(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
|
|
||||||
vec![ctx.primitives.int32, ctx.primitives.int64]
|
|
||||||
}
|
|
||||||
|
|
||||||
fn unsigned_ints(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
|
|
||||||
vec![ctx.primitives.uint32, ctx.primitives.uint64]
|
|
||||||
}
|
|
||||||
|
|
||||||
fn ints(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
|
|
||||||
vec![ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.uint32, ctx.primitives.uint64]
|
|
||||||
}
|
|
||||||
|
|
||||||
fn int_like(ctx: &CodeGenContext<'_, '_>) -> Vec<Type> {
|
|
||||||
vec![
|
|
||||||
ctx.primitives.bool,
|
|
||||||
ctx.primitives.int32,
|
|
||||||
ctx.primitives.int64,
|
|
||||||
ctx.primitives.uint32,
|
|
||||||
ctx.primitives.uint64,
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
fn cast_to_int_conversion<'ctx, 'a, G, HandleFloatFn>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
|
||||||
scalar: ScalarObject<'ctx>,
|
|
||||||
ret_int_dtype: Type,
|
|
||||||
handle_float: HandleFloatFn,
|
|
||||||
) -> ScalarObject<'ctx>
|
|
||||||
where
|
|
||||||
G: CodeGenerator + ?Sized,
|
|
||||||
HandleFloatFn:
|
|
||||||
FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, FloatValue<'ctx>) -> IntValue<'ctx>,
|
|
||||||
{
|
|
||||||
let ret_int_dtype_llvm = ctx.get_llvm_type(generator, ret_int_dtype).into_int_type();
|
|
||||||
|
|
||||||
let result = if ctx.unifier.unioned(scalar.dtype, ctx.primitives.float) {
|
|
||||||
// Special handling for floats
|
|
||||||
let n = scalar.value.into_float_value();
|
|
||||||
handle_float(generator, ctx, n)
|
|
||||||
} else if ctx.unifier.unioned_any(scalar.dtype, int_like(ctx)) {
|
|
||||||
let n = scalar.value.into_int_value();
|
|
||||||
|
|
||||||
if n.get_type().get_bit_width() <= ret_int_dtype_llvm.get_bit_width() {
|
|
||||||
ctx.builder.build_int_z_extend(n, ret_int_dtype_llvm, "zext").unwrap()
|
|
||||||
} else {
|
|
||||||
ctx.builder.build_int_truncate(n, ret_int_dtype_llvm, "trunc").unwrap()
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
unsupported_type(ctx, [scalar.dtype]);
|
|
||||||
};
|
|
||||||
|
|
||||||
assert_eq!(ret_int_dtype_llvm.get_bit_width(), result.get_type().get_bit_width()); // Sanity check
|
|
||||||
ScalarObject { value: result.into(), dtype: ret_int_dtype }
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> ScalarObject<'ctx> {
|
|
||||||
/// Convenience function. Assume this scalar has typechecker type float64, get its underlying LLVM value.
|
|
||||||
///
|
|
||||||
/// Panic if the type is wrong.
|
|
||||||
pub fn into_float64(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> FloatValue<'ctx> {
|
|
||||||
if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
|
|
||||||
self.value.into_float_value() // self.value must be a FloatValue
|
|
||||||
} else {
|
|
||||||
panic!("not a float type")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Convenience function. Assume this scalar has typechecker type int32, get its underlying LLVM value.
|
|
||||||
///
|
|
||||||
/// Panic if the type is wrong.
|
|
||||||
pub fn into_int32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
|
|
||||||
if ctx.unifier.unioned(self.dtype, ctx.primitives.int32) {
|
|
||||||
let value = self.value.into_int_value();
|
|
||||||
debug_assert_eq!(value.get_type().get_bit_width(), 32); // Sanity check
|
|
||||||
value
|
|
||||||
} else {
|
|
||||||
panic!("not a float type")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Compare two scalars. Only int-to-int and float-to-float comparisons are allowed.
|
|
||||||
/// Panic otherwise.
|
|
||||||
pub fn compare<G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
lhs: ScalarObject<'ctx>,
|
|
||||||
rhs: ScalarObject<'ctx>,
|
|
||||||
int_predicate: IntPredicate,
|
|
||||||
float_predicate: FloatPredicate,
|
|
||||||
name: &str,
|
|
||||||
) -> Int<'ctx, Bool> {
|
|
||||||
if !ctx.unifier.unioned(lhs.dtype, rhs.dtype) {
|
|
||||||
unsupported_type(ctx, [lhs.dtype, rhs.dtype]);
|
|
||||||
}
|
|
||||||
|
|
||||||
let bool_model = IntModel(Bool);
|
|
||||||
|
|
||||||
let common_ty = lhs.dtype;
|
|
||||||
let result = if ctx.unifier.unioned(common_ty, ctx.primitives.float) {
|
|
||||||
let lhs = lhs.value.into_float_value();
|
|
||||||
let rhs = rhs.value.into_float_value();
|
|
||||||
ctx.builder.build_float_compare(float_predicate, lhs, rhs, name).unwrap()
|
|
||||||
} else if ctx.unifier.unioned_any(common_ty, int_like(ctx)) {
|
|
||||||
let lhs = lhs.value.into_int_value();
|
|
||||||
let rhs = rhs.value.into_int_value();
|
|
||||||
ctx.builder.build_int_compare(int_predicate, lhs, rhs, name).unwrap()
|
|
||||||
} else {
|
|
||||||
unsupported_type(ctx, [lhs.dtype, rhs.dtype]);
|
|
||||||
};
|
|
||||||
|
|
||||||
bool_model.check_value(generator, ctx.ctx, result).unwrap()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Invoke NAC3's builtin `int32()`.
|
|
||||||
#[must_use]
|
|
||||||
pub fn cast_to_int32<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
) -> Self {
|
|
||||||
cast_to_int_conversion(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
*self,
|
|
||||||
ctx.primitives.int32,
|
|
||||||
|_generator, ctx, input| {
|
|
||||||
let n =
|
|
||||||
ctx.builder.build_float_to_signed_int(input, ctx.ctx.i64_type(), "").unwrap();
|
|
||||||
ctx.builder.build_int_truncate(n, ctx.ctx.i32_type(), "conv").unwrap()
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Invoke NAC3's builtin `int64()`.
|
|
||||||
#[must_use]
|
|
||||||
pub fn cast_to_int64<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
) -> Self {
|
|
||||||
cast_to_int_conversion(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
*self,
|
|
||||||
ctx.primitives.int64,
|
|
||||||
|_generator, ctx, input| {
|
|
||||||
ctx.builder.build_float_to_signed_int(input, ctx.ctx.i64_type(), "").unwrap()
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Invoke NAC3's builtin `uint32()`.
|
|
||||||
#[must_use]
|
|
||||||
pub fn cast_to_uint32<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
) -> Self {
|
|
||||||
cast_to_int_conversion(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
*self,
|
|
||||||
ctx.primitives.uint32,
|
|
||||||
|_generator, ctx, n| {
|
|
||||||
let n_gez = ctx
|
|
||||||
.builder
|
|
||||||
.build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "")
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let to_int32 =
|
|
||||||
ctx.builder.build_float_to_signed_int(n, ctx.ctx.i32_type(), "").unwrap();
|
|
||||||
let to_uint64 =
|
|
||||||
ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap();
|
|
||||||
|
|
||||||
ctx.builder
|
|
||||||
.build_select(
|
|
||||||
n_gez,
|
|
||||||
ctx.builder.build_int_truncate(to_uint64, ctx.ctx.i32_type(), "").unwrap(),
|
|
||||||
to_int32,
|
|
||||||
"conv",
|
|
||||||
)
|
|
||||||
.unwrap()
|
|
||||||
.into_int_value()
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Invoke NAC3's builtin `uint64()`.
|
|
||||||
#[must_use]
|
|
||||||
pub fn cast_to_uint64<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
) -> Self {
|
|
||||||
cast_to_int_conversion(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
*self,
|
|
||||||
ctx.primitives.uint64,
|
|
||||||
|_generator, ctx, n| {
|
|
||||||
let val_gez = ctx
|
|
||||||
.builder
|
|
||||||
.build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "")
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let to_int64 =
|
|
||||||
ctx.builder.build_float_to_signed_int(n, ctx.ctx.i64_type(), "").unwrap();
|
|
||||||
let to_uint64 =
|
|
||||||
ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap();
|
|
||||||
|
|
||||||
ctx.builder
|
|
||||||
.build_select(val_gez, to_uint64, to_int64, "conv")
|
|
||||||
.unwrap()
|
|
||||||
.into_int_value()
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Invoke NAC3's builtin `bool()`.
|
|
||||||
#[must_use]
|
|
||||||
pub fn cast_to_bool(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
|
|
||||||
// TODO: Why is the original code being so lax about i1 and i8 for the returned int type?
|
|
||||||
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.bool) {
|
|
||||||
self.value.into_int_value()
|
|
||||||
} else if ctx.unifier.unioned_any(self.dtype, ints(ctx)) {
|
|
||||||
let n = self.value.into_int_value();
|
|
||||||
ctx.builder
|
|
||||||
.build_int_compare(inkwell::IntPredicate::NE, n, n.get_type().const_zero(), "bool")
|
|
||||||
.unwrap()
|
|
||||||
} else if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
|
|
||||||
let n = self.value.into_float_value();
|
|
||||||
ctx.builder
|
|
||||||
.build_float_compare(FloatPredicate::UNE, n, n.get_type().const_zero(), "bool")
|
|
||||||
.unwrap()
|
|
||||||
} else {
|
|
||||||
unsupported_type(ctx, [self.dtype])
|
|
||||||
};
|
|
||||||
|
|
||||||
ScalarObject { dtype: ctx.primitives.bool, value: result.as_basic_value_enum() }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Invoke NAC3's builtin `float()`.
|
|
||||||
#[must_use]
|
|
||||||
pub fn cast_to_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
|
|
||||||
let llvm_f64 = ctx.ctx.f64_type();
|
|
||||||
|
|
||||||
let result: FloatValue<'_> = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
|
|
||||||
self.value.into_float_value()
|
|
||||||
} else if ctx
|
|
||||||
.unifier
|
|
||||||
.unioned_any(self.dtype, [signed_ints(ctx).as_slice(), &[ctx.primitives.bool]].concat())
|
|
||||||
{
|
|
||||||
let n = self.value.into_int_value();
|
|
||||||
ctx.builder.build_signed_int_to_float(n, llvm_f64, "sitofp").unwrap()
|
|
||||||
} else if ctx.unifier.unioned_any(self.dtype, unsigned_ints(ctx)) {
|
|
||||||
let n = self.value.into_int_value();
|
|
||||||
ctx.builder.build_unsigned_int_to_float(n, llvm_f64, "uitofp").unwrap()
|
|
||||||
} else {
|
|
||||||
unsupported_type(ctx, [self.dtype]);
|
|
||||||
};
|
|
||||||
|
|
||||||
ScalarObject { value: result.as_basic_value_enum(), dtype: ctx.primitives.float }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Invoke NAC3's builtin `round()`.
|
|
||||||
#[must_use]
|
|
||||||
pub fn round<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
ret_int_dtype: Type,
|
|
||||||
) -> Self {
|
|
||||||
let ret_int_dtype_llvm = ctx.get_llvm_type(generator, ret_int_dtype).into_int_type();
|
|
||||||
|
|
||||||
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
|
|
||||||
let n = self.value.into_float_value();
|
|
||||||
let n = llvm_intrinsics::call_float_round(ctx, n, None);
|
|
||||||
ctx.builder.build_float_to_signed_int(n, ret_int_dtype_llvm, "round").unwrap()
|
|
||||||
} else {
|
|
||||||
unsupported_type(ctx, [self.dtype, ret_int_dtype])
|
|
||||||
};
|
|
||||||
ScalarObject { dtype: ret_int_dtype, value: result.as_basic_value_enum() }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Invoke NAC3's builtin `np_round()`.
|
|
||||||
///
|
|
||||||
/// NOTE: `np.round()` has different behaviors than `round()` when in comes to "tie" cases and return type.
|
|
||||||
#[must_use]
|
|
||||||
pub fn np_round(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
|
|
||||||
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
|
|
||||||
let n = self.value.into_float_value();
|
|
||||||
llvm_intrinsics::call_float_rint(ctx, n, None)
|
|
||||||
} else {
|
|
||||||
unsupported_type(ctx, [self.dtype])
|
|
||||||
};
|
|
||||||
ScalarObject { dtype: ctx.primitives.float, value: result.as_basic_value_enum() }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Invoke NAC3's builtin `min()` or `max()`.
|
|
||||||
pub fn min_or_max(
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
kind: MinOrMax,
|
|
||||||
a: Self,
|
|
||||||
b: Self,
|
|
||||||
) -> Self {
|
|
||||||
if !ctx.unifier.unioned(a.dtype, b.dtype) {
|
|
||||||
unsupported_type(ctx, [a.dtype, b.dtype])
|
|
||||||
}
|
|
||||||
|
|
||||||
let common_dtype = a.dtype;
|
|
||||||
|
|
||||||
if ctx.unifier.unioned(common_dtype, ctx.primitives.float) {
|
|
||||||
let function = match kind {
|
|
||||||
MinOrMax::Min => llvm_intrinsics::call_float_minnum,
|
|
||||||
MinOrMax::Max => llvm_intrinsics::call_float_maxnum,
|
|
||||||
};
|
|
||||||
let result =
|
|
||||||
function(ctx, a.value.into_float_value(), b.value.into_float_value(), None);
|
|
||||||
ScalarObject { value: result.as_basic_value_enum(), dtype: ctx.primitives.float }
|
|
||||||
} else if ctx.unifier.unioned_any(
|
|
||||||
common_dtype,
|
|
||||||
[unsigned_ints(ctx).as_slice(), &[ctx.primitives.bool]].concat(),
|
|
||||||
) {
|
|
||||||
// Treating bool has an unsigned int since that is convenient
|
|
||||||
let function = match kind {
|
|
||||||
MinOrMax::Min => llvm_intrinsics::call_int_umin,
|
|
||||||
MinOrMax::Max => llvm_intrinsics::call_int_umax,
|
|
||||||
};
|
|
||||||
let result = function(ctx, a.value.into_int_value(), b.value.into_int_value(), None);
|
|
||||||
ScalarObject { value: result.as_basic_value_enum(), dtype: common_dtype }
|
|
||||||
} else {
|
|
||||||
unsupported_type(ctx, [common_dtype])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Invoke NAC3's builtin `floor()` or `ceil()`.
|
|
||||||
///
|
|
||||||
/// * `ret_int_dtype` - The type of int to return.
|
|
||||||
///
|
|
||||||
/// Takes in a float/int and returns an int of type `ret_int_dtype`
|
|
||||||
#[must_use]
|
|
||||||
pub fn floor_or_ceil<G: CodeGenerator + ?Sized>(
|
|
||||||
&self,
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
kind: FloorOrCeil,
|
|
||||||
ret_int_dtype: Type,
|
|
||||||
) -> Self {
|
|
||||||
let ret_int_dtype_llvm = ctx.get_llvm_type(generator, ret_int_dtype).into_int_type();
|
|
||||||
|
|
||||||
if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
|
|
||||||
let function = match kind {
|
|
||||||
FloorOrCeil::Floor => llvm_intrinsics::call_float_floor,
|
|
||||||
FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil,
|
|
||||||
};
|
|
||||||
let n = self.value.into_float_value();
|
|
||||||
let n = function(ctx, n, None);
|
|
||||||
|
|
||||||
let n = ctx.builder.build_float_to_signed_int(n, ret_int_dtype_llvm, "").unwrap();
|
|
||||||
ScalarObject { dtype: ret_int_dtype, value: n.as_basic_value_enum() }
|
|
||||||
} else {
|
|
||||||
unsupported_type(ctx, [self.dtype])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Invoke NAC3's builtin `np_floor()`/ `np_ceil()`.
|
|
||||||
///
|
|
||||||
/// Takes in a float/int and returns a float64 result.
|
|
||||||
#[must_use]
|
|
||||||
pub fn np_floor_or_ceil(&self, ctx: &mut CodeGenContext<'ctx, '_>, kind: FloorOrCeil) -> Self {
|
|
||||||
if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
|
|
||||||
let function = match kind {
|
|
||||||
FloorOrCeil::Floor => llvm_intrinsics::call_float_floor,
|
|
||||||
FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil,
|
|
||||||
};
|
|
||||||
let n = self.value.into_float_value();
|
|
||||||
let n = function(ctx, n, None);
|
|
||||||
ScalarObject { dtype: ctx.primitives.float, value: n.as_basic_value_enum() }
|
|
||||||
} else {
|
|
||||||
unsupported_type(ctx, [self.dtype])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Invoke NAC3's builtin `abs()`.
|
|
||||||
#[must_use]
|
|
||||||
pub fn abs(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
|
|
||||||
if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
|
|
||||||
let n = self.value.into_float_value();
|
|
||||||
let n = llvm_intrinsics::call_float_fabs(ctx, n, Some("abs"));
|
|
||||||
ScalarObject { value: n.into(), dtype: ctx.primitives.float }
|
|
||||||
} else if ctx.unifier.unioned_any(self.dtype, ints(ctx)) {
|
|
||||||
let n = self.value.into_int_value();
|
|
||||||
|
|
||||||
let is_poisoned = ctx.ctx.bool_type().const_zero(); // is_poisoned = false
|
|
||||||
let n = llvm_intrinsics::call_int_abs(ctx, n, is_poisoned, Some("abs"));
|
|
||||||
|
|
||||||
ScalarObject { value: n.into(), dtype: self.dtype }
|
|
||||||
} else {
|
|
||||||
unsupported_type(ctx, [self.dtype])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> NDArrayObject<'ctx> {
|
impl<'ctx> NDArrayObject<'ctx> {
|
||||||
/// Helper function to implement NAC3's builtin `np_min()`, `np_max()`, `np_argmin()`, and `np_argmax()`.
|
/// Helper function to implement NAC3's builtin `np_min()`, `np_max()`, `np_argmin()`, and `np_argmax()`.
|
||||||
|
@ -451,7 +21,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
kind: MinOrMax,
|
kind: MinOrMax,
|
||||||
on_empty_err_msg: &str,
|
on_empty_err_msg: &str,
|
||||||
) -> (ScalarObject<'ctx>, Int<'ctx, SizeT>) {
|
) -> (AnyObject<'ctx>, Int<'ctx, SizeT>) {
|
||||||
let sizet_model = IntModel(SizeT);
|
let sizet_model = IntModel(SizeT);
|
||||||
let dtype_llvm = ctx.get_llvm_type(generator, self.dtype);
|
let dtype_llvm = ctx.get_llvm_type(generator, self.dtype);
|
||||||
|
|
||||||
|
@ -478,17 +48,17 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
|
|
||||||
self.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
|
self.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
|
||||||
let old_extremum = ctx.builder.build_load(pextremum, "current_extremum").unwrap();
|
let old_extremum = ctx.builder.build_load(pextremum, "current_extremum").unwrap();
|
||||||
let old_extremum = ScalarObject { dtype: self.dtype, value: old_extremum };
|
let old_extremum = AnyObject { ty: self.dtype, value: old_extremum };
|
||||||
|
|
||||||
let scalar = nditer.get_scalar(generator, ctx);
|
let scalar = nditer.get_scalar(generator, ctx);
|
||||||
let new_extremum = ScalarObject::min_or_max(ctx, kind, old_extremum, scalar);
|
let new_extremum = AnyObject::call_min_or_max(ctx, kind, old_extremum, scalar);
|
||||||
|
|
||||||
gen_if_callback(
|
gen_if_callback(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
|generator, ctx| {
|
|generator, ctx| {
|
||||||
// Is new_extremum is more extreme than old_extremum?
|
// Is new_extremum is more extreme than old_extremum?
|
||||||
let cmp = ScalarObject::compare(
|
let cmp = AnyObject::compare_int_or_float_by_predicate(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
new_extremum,
|
new_extremum,
|
||||||
|
@ -517,7 +87,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
let extremum_index = pextremum_index.load(generator, ctx, "extremum_index");
|
let extremum_index = pextremum_index.load(generator, ctx, "extremum_index");
|
||||||
|
|
||||||
let extremum = ctx.builder.build_load(pextremum, "extremum_value").unwrap();
|
let extremum = ctx.builder.build_load(pextremum, "extremum_value").unwrap();
|
||||||
let extremum = ScalarObject { dtype: self.dtype, value: extremum };
|
let extremum = AnyObject { ty: self.dtype, value: extremum };
|
||||||
|
|
||||||
(extremum, extremum_index)
|
(extremum, extremum_index)
|
||||||
}
|
}
|
||||||
|
@ -528,7 +98,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
kind: MinOrMax,
|
kind: MinOrMax,
|
||||||
) -> ScalarObject<'ctx> {
|
) -> AnyObject<'ctx> {
|
||||||
let on_empty_err_msg = format!(
|
let on_empty_err_msg = format!(
|
||||||
"zero-size array to reduction operation {} which has no identity",
|
"zero-size array to reduction operation {} which has no identity",
|
||||||
match kind {
|
match kind {
|
||||||
|
|
|
@ -1,9 +1,8 @@
|
||||||
use inkwell::values::BasicValueEnum;
|
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
codegen::{
|
codegen::{
|
||||||
object::ndarray::{NDArrayObject, ScalarObject},
|
object::ndarray::{AnyObject, NDArrayObject},
|
||||||
stmt::gen_for_callback,
|
stmt::gen_for_callback,
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
},
|
},
|
||||||
|
@ -27,8 +26,8 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
MappingFn: FnOnce(
|
MappingFn: FnOnce(
|
||||||
&mut G,
|
&mut G,
|
||||||
&mut CodeGenContext<'ctx, 'a>,
|
&mut CodeGenContext<'ctx, 'a>,
|
||||||
&[ScalarObject<'ctx>],
|
&[AnyObject<'ctx>],
|
||||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
) -> Result<AnyObject<'ctx>, String>,
|
||||||
{
|
{
|
||||||
// Broadcast inputs
|
// Broadcast inputs
|
||||||
let broadcast_result = NDArrayObject::broadcast(generator, ctx, ndarrays);
|
let broadcast_result = NDArrayObject::broadcast(generator, ctx, ndarrays);
|
||||||
|
@ -89,9 +88,11 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
in_nditers.iter().map(|nditer| nditer.get_scalar(generator, ctx)).collect_vec();
|
in_nditers.iter().map(|nditer| nditer.get_scalar(generator, ctx)).collect_vec();
|
||||||
|
|
||||||
let result = mapping(generator, ctx, &in_scalars)?;
|
let result = mapping(generator, ctx, &in_scalars)?;
|
||||||
|
// Sanity check on result's ty
|
||||||
|
assert!(ctx.unifier.unioned(result.ty, out_ndarray.dtype));
|
||||||
|
|
||||||
let p = out_nditer.get_pointer(generator, ctx);
|
let p = out_nditer.get_pointer(generator, ctx);
|
||||||
ctx.builder.build_store(p, result).unwrap();
|
ctx.builder.build_store(p, result.value).unwrap();
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
},
|
},
|
||||||
|
@ -103,20 +104,6 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
},
|
},
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
// let start = sizet_model.const_0(generator, ctx.ctx);
|
|
||||||
// let stop = broadcast_result.ndarrays[0].size(generator, ctx); // They all should have the same `np.size`.
|
|
||||||
// let step = sizet_model.const_1(generator, ctx.ctx);
|
|
||||||
// gen_for_model_auto(generator, ctx, start, stop, step, move |generator, ctx, _hooks, i| {
|
|
||||||
// let elements =
|
|
||||||
// ndarrays.iter().map(|ndarray| ndarray.get_nth_scalar(generator, ctx, i)).collect_vec();
|
|
||||||
|
|
||||||
// let ret = mapping(generator, ctx, i, &elements)?;
|
|
||||||
|
|
||||||
// let pret = out_ndarray.get_nth_pointer(generator, ctx, i, "pret");
|
|
||||||
// ctx.builder.build_store(pret, ret).unwrap();
|
|
||||||
// Ok(())
|
|
||||||
// })?;
|
|
||||||
|
|
||||||
Ok(out_ndarray)
|
Ok(out_ndarray)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -132,8 +119,8 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
Mapping: FnOnce(
|
Mapping: FnOnce(
|
||||||
&mut G,
|
&mut G,
|
||||||
&mut CodeGenContext<'ctx, 'a>,
|
&mut CodeGenContext<'ctx, 'a>,
|
||||||
ScalarObject<'ctx>,
|
AnyObject<'ctx>,
|
||||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
) -> Result<AnyObject<'ctx>, String>,
|
||||||
{
|
{
|
||||||
NDArrayObject::broadcasting_starmap(
|
NDArrayObject::broadcasting_starmap(
|
||||||
generator,
|
generator,
|
||||||
|
@ -160,16 +147,18 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||||
MappingFn: FnOnce(
|
MappingFn: FnOnce(
|
||||||
&mut G,
|
&mut G,
|
||||||
&mut CodeGenContext<'ctx, 'a>,
|
&mut CodeGenContext<'ctx, 'a>,
|
||||||
&[ScalarObject<'ctx>],
|
&[AnyObject<'ctx>],
|
||||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
) -> Result<AnyObject<'ctx>, String>,
|
||||||
{
|
{
|
||||||
// Check if all inputs are ScalarObjects
|
// Check if all inputs are AnyObjects
|
||||||
let all_scalars: Option<Vec<_>> =
|
let all_scalars: Option<Vec<_>> = inputs.iter().map(AnyObject::try_from).try_collect().ok();
|
||||||
inputs.iter().map(ScalarObject::try_from).try_collect().ok();
|
|
||||||
|
|
||||||
if let Some(scalars) = all_scalars {
|
if let Some(scalars) = all_scalars {
|
||||||
let scalar =
|
let scalar = mapping(generator, ctx, &scalars)?;
|
||||||
ScalarObject { value: mapping(generator, ctx, &scalars)?, dtype: ret_dtype };
|
|
||||||
|
// Sanity check on scalar's type
|
||||||
|
assert!(ctx.unifier.unioned(scalar.ty, ret_dtype));
|
||||||
|
|
||||||
Ok(ScalarOrNDArray::Scalar(scalar))
|
Ok(ScalarOrNDArray::Scalar(scalar))
|
||||||
} else {
|
} else {
|
||||||
// Promote all input to ndarrays and map through them.
|
// Promote all input to ndarrays and map through them.
|
||||||
|
@ -197,8 +186,8 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||||
Mapping: FnOnce(
|
Mapping: FnOnce(
|
||||||
&mut G,
|
&mut G,
|
||||||
&mut CodeGenContext<'ctx, 'a>,
|
&mut CodeGenContext<'ctx, 'a>,
|
||||||
ScalarObject<'ctx>,
|
AnyObject<'ctx>,
|
||||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
) -> Result<AnyObject<'ctx>, String>,
|
||||||
{
|
{
|
||||||
ScalarOrNDArray::broadcasting_starmap(
|
ScalarOrNDArray::broadcasting_starmap(
|
||||||
generator,
|
generator,
|
||||||
|
|
|
@ -38,7 +38,7 @@ use inkwell::{
|
||||||
AddressSpace, IntPredicate,
|
AddressSpace, IntPredicate,
|
||||||
};
|
};
|
||||||
use nditer::NDIterHandle;
|
use nditer::NDIterHandle;
|
||||||
use scalar::{ScalarObject, ScalarOrNDArray};
|
use scalar::ScalarOrNDArray;
|
||||||
use util::call_memcpy_model;
|
use util::call_memcpy_model;
|
||||||
|
|
||||||
use super::{tuple::TupleObject, AnyObject};
|
use super::{tuple::TupleObject, AnyObject};
|
||||||
|
@ -257,10 +257,10 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
nth: Int<'ctx, SizeT>,
|
nth: Int<'ctx, SizeT>,
|
||||||
) -> ScalarObject<'ctx> {
|
) -> AnyObject<'ctx> {
|
||||||
let p = self.get_nth_pointer(generator, ctx, nth, "value");
|
let p = self.get_nth_pointer(generator, ctx, nth, "value");
|
||||||
let value = ctx.builder.build_load(p, "value").unwrap();
|
let value = ctx.builder.build_load(p, "value").unwrap();
|
||||||
ScalarObject { dtype: self.dtype, value }
|
AnyObject { ty: self.dtype, value }
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set the n-th (0-based) scalar.
|
/// Set the n-th (0-based) scalar.
|
||||||
|
@ -271,10 +271,10 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
nth: Int<'ctx, SizeT>,
|
nth: Int<'ctx, SizeT>,
|
||||||
scalar: ScalarObject<'ctx>,
|
scalar: AnyObject<'ctx>,
|
||||||
) {
|
) {
|
||||||
// Sanity check on scalar's `dtype`
|
// Sanity check on scalar's `dtype`
|
||||||
assert!(ctx.unifier.unioned(scalar.dtype, self.dtype));
|
assert!(ctx.unifier.unioned(scalar.ty, self.dtype));
|
||||||
|
|
||||||
let pscalar = self.get_nth_pointer(generator, ctx, nth, "pscalar");
|
let pscalar = self.get_nth_pointer(generator, ctx, nth, "pscalar");
|
||||||
ctx.builder.build_store(pscalar, scalar.value).unwrap();
|
ctx.builder.build_store(pscalar, scalar.value).unwrap();
|
||||||
|
@ -284,7 +284,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
///
|
///
|
||||||
/// Please refer to the IRRT implementation to see its purpose.
|
/// Please refer to the IRRT implementation to see its purpose.
|
||||||
pub fn update_strides_by_shape<G: CodeGenerator + ?Sized>(
|
pub fn update_strides_by_shape<G: CodeGenerator + ?Sized>(
|
||||||
&self,
|
self,
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
) {
|
) {
|
||||||
|
@ -458,7 +458,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
self.ndims == 0
|
self.ndims == 0
|
||||||
}
|
}
|
||||||
|
|
||||||
/// If this ndarray is unsized, return its sole value as a [`ScalarObject`]. Otherwise, do nothing.
|
/// If this ndarray is unsized, return its sole value as a [`AnyObject`]. Otherwise, do nothing.
|
||||||
pub fn split_unsized<G: CodeGenerator + ?Sized>(
|
pub fn split_unsized<G: CodeGenerator + ?Sized>(
|
||||||
&self,
|
&self,
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
|
@ -604,10 +604,10 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
&self,
|
&self,
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
scalar: ScalarObject<'ctx>,
|
scalar: AnyObject<'ctx>,
|
||||||
) {
|
) {
|
||||||
// Sanity check on scalar's type.
|
// Sanity check on scalar's type.
|
||||||
assert!(ctx.unifier.unioned(self.dtype, scalar.dtype));
|
assert!(ctx.unifier.unioned(self.dtype, scalar.ty));
|
||||||
|
|
||||||
self.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
|
self.foreach(generator, ctx, |generator, ctx, _hooks, nditer| {
|
||||||
let p = nditer.get_pointer(generator, ctx);
|
let p = nditer.get_pointer(generator, ctx);
|
||||||
|
|
|
@ -3,7 +3,7 @@ use inkwell::{types::BasicType, values::PointerValue, AddressSpace};
|
||||||
use crate::codegen::{
|
use crate::codegen::{
|
||||||
irrt::{call_nac3_nditer_has_next, call_nac3_nditer_initialize, call_nac3_nditer_next},
|
irrt::{call_nac3_nditer_has_next, call_nac3_nditer_initialize, call_nac3_nditer_next},
|
||||||
model::*,
|
model::*,
|
||||||
object::ndarray::scalar::ScalarObject,
|
object::AnyObject,
|
||||||
structure::NDIter,
|
structure::NDIter,
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
};
|
};
|
||||||
|
@ -63,10 +63,10 @@ impl<'ctx> NDIterHandle<'ctx> {
|
||||||
&self,
|
&self,
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
) -> ScalarObject<'ctx> {
|
) -> AnyObject<'ctx> {
|
||||||
let p = self.get_pointer(generator, ctx);
|
let p = self.get_pointer(generator, ctx);
|
||||||
let value = ctx.builder.build_load(p, "value").unwrap();
|
let value = ctx.builder.build_load(p, "value").unwrap();
|
||||||
ScalarObject { dtype: self.ndarray.dtype, value }
|
AnyObject { ty: self.ndarray.dtype, value }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_index<G: CodeGenerator + ?Sized>(
|
pub fn get_index<G: CodeGenerator + ?Sized>(
|
||||||
|
|
|
@ -7,18 +7,7 @@ use crate::{
|
||||||
|
|
||||||
use super::NDArrayObject;
|
use super::NDArrayObject;
|
||||||
|
|
||||||
/// An LLVM numpy scalar with its [`Type`].
|
impl<'ctx> AnyObject<'ctx> {
|
||||||
///
|
|
||||||
/// Intended to be used with [`ScalarOrNDArray`].
|
|
||||||
///
|
|
||||||
/// A scalar does not have to be an actual number. It could be arbitrary objects.
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
|
||||||
pub struct ScalarObject<'ctx> {
|
|
||||||
pub dtype: Type,
|
|
||||||
pub value: BasicValueEnum<'ctx>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'ctx> ScalarObject<'ctx> {
|
|
||||||
/// Promote this scalar to an unsized ndarray (like doing `np.asarray`).
|
/// Promote this scalar to an unsized ndarray (like doing `np.asarray`).
|
||||||
///
|
///
|
||||||
/// The scalar value is allocated onto the stack, and the ndarray's `data` will point to that
|
/// The scalar value is allocated onto the stack, and the ndarray's `data` will point to that
|
||||||
|
@ -35,7 +24,7 @@ impl<'ctx> ScalarObject<'ctx> {
|
||||||
ctx.builder.build_store(data, self.value).unwrap();
|
ctx.builder.build_store(data, self.value).unwrap();
|
||||||
let data = pbyte_model.pointer_cast(generator, ctx, data, "data");
|
let data = pbyte_model.pointer_cast(generator, ctx, data, "data");
|
||||||
|
|
||||||
let ndarray = NDArrayObject::alloca(generator, ctx, self.dtype, 0, "scalar_ndarray");
|
let ndarray = NDArrayObject::alloca(generator, ctx, self.ty, 0, "scalar_ndarray");
|
||||||
ndarray.instance.set(ctx, |f| f.data, data);
|
ndarray.instance.set(ctx, |f| f.data, data);
|
||||||
ndarray
|
ndarray
|
||||||
}
|
}
|
||||||
|
@ -44,7 +33,7 @@ impl<'ctx> ScalarObject<'ctx> {
|
||||||
/// A convenience enum for implementing scalar/ndarray agnostic utilities.
|
/// A convenience enum for implementing scalar/ndarray agnostic utilities.
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
pub enum ScalarOrNDArray<'ctx> {
|
pub enum ScalarOrNDArray<'ctx> {
|
||||||
Scalar(ScalarObject<'ctx>),
|
Scalar(AnyObject<'ctx>),
|
||||||
NDArray(NDArrayObject<'ctx>),
|
NDArray(NDArrayObject<'ctx>),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -59,7 +48,7 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn into_scalar(&self) -> ScalarObject<'ctx> {
|
pub fn into_scalar(&self) -> AnyObject<'ctx> {
|
||||||
match self {
|
match self {
|
||||||
ScalarOrNDArray::NDArray(_ndarray) => panic!("Got NDArray"),
|
ScalarOrNDArray::NDArray(_ndarray) => panic!("Got NDArray"),
|
||||||
ScalarOrNDArray::Scalar(scalar) => *scalar,
|
ScalarOrNDArray::Scalar(scalar) => *scalar,
|
||||||
|
@ -91,13 +80,13 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn dtype(&self) -> Type {
|
pub fn dtype(&self) -> Type {
|
||||||
match self {
|
match self {
|
||||||
ScalarOrNDArray::Scalar(scalar) => scalar.dtype,
|
ScalarOrNDArray::Scalar(scalar) => scalar.ty,
|
||||||
ScalarOrNDArray::NDArray(ndarray) => ndarray.dtype,
|
ScalarOrNDArray::NDArray(ndarray) => ndarray.dtype,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for ScalarObject<'ctx> {
|
impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for AnyObject<'ctx> {
|
||||||
type Error = ();
|
type Error = ();
|
||||||
|
|
||||||
fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result<Self, Self::Error> {
|
fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result<Self, Self::Error> {
|
||||||
|
@ -135,7 +124,7 @@ pub fn split_scalar_or_ndarray<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
ScalarOrNDArray::NDArray(ndarray)
|
ScalarOrNDArray::NDArray(ndarray)
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
let scalar = ScalarObject { dtype: object.ty, value: object.value };
|
let scalar = AnyObject { ty: object.ty, value: object.value };
|
||||||
ScalarOrNDArray::Scalar(scalar)
|
ScalarOrNDArray::Scalar(scalar)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1143,12 +1143,12 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
ret_dtype,
|
ret_dtype,
|
||||||
|generator, ctx, scalar| {
|
|generator, ctx, scalar| {
|
||||||
let result = match prim {
|
let result = match prim {
|
||||||
PrimDef::FunInt32 => scalar.cast_to_int32(generator, ctx),
|
PrimDef::FunInt32 => scalar.call_int32(generator, ctx),
|
||||||
PrimDef::FunInt64 => scalar.cast_to_int64(generator, ctx),
|
PrimDef::FunInt64 => scalar.call_int64(generator, ctx),
|
||||||
PrimDef::FunUInt32 => scalar.cast_to_uint32(generator, ctx),
|
PrimDef::FunUInt32 => scalar.call_uint32(generator, ctx),
|
||||||
PrimDef::FunUInt64 => scalar.cast_to_uint64(generator, ctx),
|
PrimDef::FunUInt64 => scalar.call_uint64(generator, ctx),
|
||||||
PrimDef::FunFloat => scalar.cast_to_float(ctx),
|
PrimDef::FunFloat => scalar.call_float(ctx),
|
||||||
PrimDef::FunBool => scalar.cast_to_bool(ctx),
|
PrimDef::FunBool => scalar.call_bool(ctx),
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
};
|
};
|
||||||
Ok(result.value)
|
Ok(result.value)
|
||||||
|
@ -1277,7 +1277,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
ctx,
|
ctx,
|
||||||
int_sized,
|
int_sized,
|
||||||
|generator, ctx, scalar| {
|
|generator, ctx, scalar| {
|
||||||
Ok(scalar.floor_or_ceil(generator, ctx, kind, int_sized).value)
|
Ok(scalar.call_floor_or_ceil(generator, ctx, kind, int_sized).value)
|
||||||
},
|
},
|
||||||
)?;
|
)?;
|
||||||
Ok(Some(result.to_basic_value_enum()))
|
Ok(Some(result.to_basic_value_enum()))
|
||||||
|
@ -1927,7 +1927,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||||
let arg = AnyObject { value: arg, ty: arg_ty };
|
let arg = AnyObject { value: arg, ty: arg_ty };
|
||||||
|
|
||||||
Ok(Some(arg.len(generator, ctx).value.as_basic_value_enum()))
|
Ok(Some(arg.call_len(generator, ctx).value.as_basic_value_enum()))
|
||||||
},
|
},
|
||||||
)))),
|
)))),
|
||||||
loc: None,
|
loc: None,
|
||||||
|
|
Loading…
Reference in New Issue