From 4b765cfb270785519f5343344ff9b3284589c75b Mon Sep 17 00:00:00 2001 From: lyken Date: Thu, 15 Aug 2024 13:34:48 +0800 Subject: [PATCH] WIP: core/ndstrides: remove ScalarObject --- nac3core/src/codegen/object/mod.rs | 569 +++++++++++++++++- .../src/codegen/object/ndarray/factory.rs | 14 +- .../src/codegen/object/ndarray/functions.rs | 456 +------------- .../src/codegen/object/ndarray/mapping.rs | 49 +- nac3core/src/codegen/object/ndarray/mod.rs | 18 +- nac3core/src/codegen/object/ndarray/nditer.rs | 6 +- nac3core/src/codegen/object/ndarray/scalar.rs | 25 +- nac3core/src/toplevel/builtins.rs | 16 +- 8 files changed, 627 insertions(+), 526 deletions(-) diff --git a/nac3core/src/codegen/object/mod.rs b/nac3core/src/codegen/object/mod.rs index 729d6984..82344f3e 100644 --- a/nac3core/src/codegen/object/mod.rs +++ b/nac3core/src/codegen/object/mod.rs @@ -1,18 +1,71 @@ -use inkwell::values::BasicValueEnum; +use inkwell::{ + values::{BasicValue, BasicValueEnum, FloatValue, IntValue}, + FloatPredicate, IntPredicate, +}; +use itertools::Itertools; use list::ListObject; -use ndarray::NDArrayObject; +use ndarray::{NDArrayObject, NDArrayOut}; use range::RangeObject; use tuple::TupleObject; -use crate::typecheck::typedef::{Type, TypeEnum}; +use crate::{ + toplevel::helper::PrimDef, + typecheck::typedef::{Type, TypeEnum}, +}; -use super::{model::*, CodeGenContext, CodeGenerator}; +use super::{llvm_intrinsics, model::*, CodeGenContext, CodeGenerator}; pub mod list; pub mod ndarray; pub mod range; pub mod tuple; +/// Convenience function to crash the program when types of arguments are not supported. +/// Used to be debugged with a stacktrace. +fn unsupported_type(ctx: &CodeGenContext<'_, '_>, tys: I) -> ! +where + I: IntoIterator, +{ + unreachable!( + "unsupported types found '{}'", + tys.into_iter().map(|ty| format!("'{}'", ctx.unifier.stringify(ty))).join(", "), + ) +} + +#[derive(Debug, Clone, Copy)] +pub enum FloorOrCeil { + Floor, + Ceil, +} + +#[derive(Debug, Clone, Copy)] +pub enum MinOrMax { + Min, + Max, +} + +fn signed_ints(ctx: &CodeGenContext<'_, '_>) -> Vec { + vec![ctx.primitives.int32, ctx.primitives.int64] +} + +fn unsigned_ints(ctx: &CodeGenContext<'_, '_>) -> Vec { + vec![ctx.primitives.uint32, ctx.primitives.uint64] +} + +fn ints(ctx: &CodeGenContext<'_, '_>) -> Vec { + vec![ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.uint32, ctx.primitives.uint64] +} + +fn int_like(ctx: &CodeGenContext<'_, '_>) -> Vec { + vec![ + ctx.primitives.bool, + ctx.primitives.int32, + ctx.primitives.int64, + ctx.primitives.uint32, + ctx.primitives.uint64, + ] +} + #[derive(Debug, Clone, Copy)] pub struct AnyObject<'ctx> { pub ty: Type, @@ -20,13 +73,260 @@ pub struct AnyObject<'ctx> { } impl<'ctx> AnyObject<'ctx> { - // Get the `len()` of this object. - pub fn len( + /// Returns true if this object's type is a [`TypeEnum::TObj`] and has the object ID as `prim`. + pub fn is_obj(&self, ctx: &mut CodeGenContext<'ctx, '_>, prim: PrimDef) -> bool { + match &*ctx.unifier.get_ty(self.ty) { + TypeEnum::TObj { obj_id, .. } => *obj_id == prim.id(), + _ => false, + } + } + + /// Returns true if this object's type is a [`TypeEnum::TTuple`] + pub fn is_tuple(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool { + matches!(&*ctx.unifier.get_ty(self.ty), TypeEnum::TTuple { .. }) + } + + pub fn is_int32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool { + ctx.unifier.unioned(self.ty, ctx.primitives.int32) + } + + pub fn is_uint32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool { + ctx.unifier.unioned(self.ty, ctx.primitives.uint32) + } + + pub fn is_int64(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool { + ctx.unifier.unioned(self.ty, ctx.primitives.int64) + } + + pub fn is_uint64(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool { + ctx.unifier.unioned(self.ty, ctx.primitives.uint64) + } + + pub fn is_bool(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool { + ctx.unifier.unioned(self.ty, ctx.primitives.bool) + } + + pub fn is_int_like(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool { + ctx.unifier.unioned_any(self.ty, int_like(ctx)) + } + + pub fn is_signed_int(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool { + ctx.unifier.unioned_any(self.ty, signed_ints(ctx)) + } + + pub fn is_unsigned_int(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool { + ctx.unifier.unioned_any(self.ty, unsigned_ints(ctx)) + } + + /// Convenience function. If object has type `int32`, `int64`, `uint32`, `uint64`, or `bool`, + /// get its underlying LLVM value. + /// + /// Panic if the type is wrong. + pub fn into_int(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { + if self.is_int_like(ctx) { + self.value.into_int_value() + } else { + panic!("not an int32 type") + } + } + + pub fn is_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool { + self.is_obj(ctx, PrimDef::Float) + } + + /// Convenience function. If object has type `float`, get its underlying LLVM value. + /// + /// Panic if the type is wrong. + pub fn into_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Float<'ctx, Float64> { + if self.is_float(ctx) { + // self.value must be a FloatValue + FloatModel(Float64).believe_value(self.value.into_float_value()) + } else { + panic!("not a float type") + } + } + + pub fn is_ndarray(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool { + self.is_obj(ctx, PrimDef::NDArray) + } + + pub fn into_ndarray( &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - ) -> Int<'ctx, Int32> { - match &*ctx.unifier.get_ty_immutable(self.ty) { + ) -> NDArrayObject<'ctx> { + NDArrayObject::from_object(generator, ctx, *self) + } + + /// Helper function to compare two scalars. + /// + /// Only int-to-int and float-to-float comparisons are allowed. + /// + /// Panic otherwise. + pub fn compare_int_or_float_by_predicate( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + lhs: AnyObject<'ctx>, + rhs: AnyObject<'ctx>, + int_predicate: IntPredicate, + float_predicate: FloatPredicate, + name: &str, + ) -> Int<'ctx, Bool> { + if !ctx.unifier.unioned(lhs.ty, rhs.ty) { + panic!("lhs and rhs type are not the same.") + } + + let bool_model = IntModel(Bool); + + let common_ty = lhs.ty; + let result = if lhs.is_float(ctx) { + let lhs = lhs.into_float(ctx); + let rhs = rhs.into_float(ctx); + ctx.builder.build_float_compare(float_predicate, lhs.value, rhs.value, name).unwrap() + } else if ctx.unifier.unioned_any(common_ty, int_like(ctx)) { + let lhs = lhs.into_int(ctx); + let rhs = rhs.into_int(ctx); + ctx.builder.build_int_compare(int_predicate, lhs, rhs, name).unwrap() + } else { + unsupported_type(ctx, [lhs.ty, rhs.ty]); + }; + + bool_model.check_value(generator, ctx.ctx, result).unwrap() + } + + /// Helper function for `int32()`, `int64()`, `uint32()`, and `uint64()`. + /// + /// TODO: Document me + fn cast_to_int_conversion<'a, G, HandleFloatFn>( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, 'a>, + ret_int_ty: Type, + handle_float: HandleFloatFn, + ) -> AnyObject<'ctx> + where + G: CodeGenerator + ?Sized, + HandleFloatFn: + FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, FloatValue<'ctx>) -> IntValue<'ctx>, + { + let ret_int_ty_llvm = ctx.get_llvm_type(generator, ret_int_ty).into_int_type(); + + let result = if self.is_float(ctx) { + // Handle float to int + let n = self.into_float(ctx); + handle_float(generator, ctx, n.value) + } else if self.is_int_like(ctx) { + // Handle int to a new int type + let n = self.into_int(ctx); + if n.get_type().get_bit_width() <= ret_int_ty_llvm.get_bit_width() { + ctx.builder.build_int_z_extend(n, ret_int_ty_llvm, "zext").unwrap() + } else { + ctx.builder.build_int_truncate(n, ret_int_ty_llvm, "trunc").unwrap() + } + } else { + unsupported_type(ctx, [self.ty]); + }; + + assert_eq!(ret_int_ty_llvm.get_bit_width(), result.get_type().get_bit_width()); // Sanity check + AnyObject { value: result.into(), ty: ret_int_ty } + } + + /// Call `int32()` on this object. + #[must_use] + pub fn call_int32( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> AnyObject<'ctx> { + self.cast_to_int_conversion( + generator, + ctx, + ctx.primitives.int32, + |_generator, ctx, input| { + let n = + ctx.builder.build_float_to_signed_int(input, ctx.ctx.i64_type(), "").unwrap(); + ctx.builder.build_int_truncate(n, ctx.ctx.i32_type(), "conv").unwrap() + }, + ) + } + + /// Call `int64()` on this object. + #[must_use] + pub fn call_int64( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> AnyObject<'ctx> { + self.cast_to_int_conversion( + generator, + ctx, + ctx.primitives.int64, + |_generator, ctx, input| { + ctx.builder.build_float_to_signed_int(input, ctx.ctx.i64_type(), "").unwrap() + }, + ) + } + + /// Call `uint32()` on this object. + #[must_use] + pub fn call_uint32( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> AnyObject<'ctx> { + self.cast_to_int_conversion(generator, ctx, ctx.primitives.uint32, |_generator, ctx, n| { + let n_gez = ctx + .builder + .build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "") + .unwrap(); + + let to_int32 = + ctx.builder.build_float_to_signed_int(n, ctx.ctx.i32_type(), "").unwrap(); + let to_uint64 = + ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap(); + + ctx.builder + .build_select( + n_gez, + ctx.builder.build_int_truncate(to_uint64, ctx.ctx.i32_type(), "").unwrap(), + to_int32, + "conv", + ) + .unwrap() + .into_int_value() + }) + } + + /// Call `uint64()` on this object. + #[must_use] + pub fn call_uint64( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> AnyObject<'ctx> { + self.cast_to_int_conversion(generator, ctx, ctx.primitives.uint64, |_generator, ctx, n| { + let val_gez = ctx + .builder + .build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "") + .unwrap(); + + let to_int64 = + ctx.builder.build_float_to_signed_int(n, ctx.ctx.i64_type(), "").unwrap(); + let to_uint64 = + ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap(); + + ctx.builder.build_select(val_gez, to_uint64, to_int64, "conv").unwrap().into_int_value() + }) + } + + // Get the `len()` of this object. + pub fn call_len( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> AnyObject<'ctx> { + // TODO: Switch to returning SizeT + let result = match &*ctx.unifier.get_ty_immutable(self.ty) { TypeEnum::TTuple { .. } => { let tuple = TupleObject::from_object(ctx, *self); tuple.len(generator, ctx).truncate(generator, ctx, Int32, "tuple_len_32") @@ -50,6 +350,259 @@ impl<'ctx> AnyObject<'ctx> { ndarray.len(generator, ctx).truncate(generator, ctx, Int32, "ndarray_len_i32") } _ => unreachable!(), + }; + + AnyObject { ty: ctx.primitives.int32, value: result.value.as_basic_value_enum() } + } + + /// Like [`AnyObject::call_bool`] but this returns an `Int<'ctx, Bool>` instead of an object. + pub fn bool(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Int<'ctx, Bool> { + let bool_model = IntModel(Bool); + if self.is_int_like(ctx) { + let n = self.into_int(ctx); + let n = ctx + .builder + .build_int_compare(inkwell::IntPredicate::NE, n, n.get_type().const_zero(), "bool") + .unwrap(); + bool_model.believe_value(n) + } else if self.is_float(ctx) { + let n = self.value.into_float_value(); + let n = ctx + .builder + .build_float_compare(FloatPredicate::UNE, n, n.get_type().const_zero(), "bool") + .unwrap(); + bool_model.believe_value(n) + } else { + unsupported_type(ctx, [self.ty]) + } + } + + /// Call `bool()` on this object. + pub fn call_bool( + &self, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> AnyObject<'ctx> { + let n = self.bool(ctx); + + // NAC3 booleans are i8 + let llvm_i8 = ctx.ctx.i8_type(); + let n = ctx.builder.build_int_z_extend(n.value, llvm_i8, "bool").unwrap(); + + AnyObject { ty: ctx.primitives.bool, value: n.as_basic_value_enum() } + } + + /// Call `float()` on this object. + pub fn call_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> AnyObject<'ctx> { + let f64_model = FloatModel(Float64); + let llvm_f64 = ctx.ctx.f64_type(); + + let result = if self.is_float(ctx) { + self.into_float(ctx) + } else if self.is_signed_int(ctx) || self.is_bool(ctx) { + let n = self.into_int(ctx); + let n = ctx.builder.build_signed_int_to_float(n, llvm_f64, "sitofp").unwrap(); + f64_model.believe_value(n) + } else if self.is_unsigned_int(ctx) { + let n = self.into_int(ctx); + let n = ctx.builder.build_unsigned_int_to_float(n, llvm_f64, "uitofp").unwrap(); + f64_model.believe_value(n) + } else { + unsupported_type(ctx, [self.ty]); + }; + + AnyObject { ty: ctx.primitives.float, value: result.value.as_basic_value_enum() } + } + + // Call `abs()` on this object. + #[must_use] + pub fn call_abs( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> Self { + if self.is_float(ctx) { + let n = self.value.into_float_value(); + let n = llvm_intrinsics::call_float_fabs(ctx, n, Some("abs")); + AnyObject { value: n.into(), ty: ctx.primitives.float } + } else if self.is_unsigned_int(ctx) || self.is_signed_int(ctx) { + let is_poisoned = ctx.ctx.bool_type().const_zero(); // is_poisoned = false + + let n = self.value.into_int_value(); + let n = llvm_intrinsics::call_int_abs(ctx, n, is_poisoned, Some("abs")); + AnyObject { value: n.into(), ty: self.ty } + } else if self.is_ndarray(ctx) { + let ndarray = self.into_ndarray(generator, ctx); + ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ndarray.dtype }, + |generator, ctx, scalar| Ok(scalar.call_abs(generator, ctx)), + ) + .unwrap() + .to_any_object(ctx) + } else { + unsupported_type(ctx, [self.ty]) + } + } + + // Call `round()` on this object. + // + // It is possible to specify which kind of int type to return. + pub fn call_round( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ret_int_ty: Type, + ) -> AnyObject<'ctx> { + let ret_int_ty_llvm = ctx.get_llvm_type(generator, ret_int_ty).into_int_type(); + + let result = if ctx.unifier.unioned(self.ty, ctx.primitives.float) { + let n = self.value.into_float_value(); + let n = llvm_intrinsics::call_float_round(ctx, n, None); + ctx.builder.build_float_to_signed_int(n, ret_int_ty_llvm, "round").unwrap() + } else { + unsupported_type(ctx, [self.ty]) + }; + AnyObject { ty: ret_int_ty, value: result.as_basic_value_enum() } + } + + /// Call `np_round()` on this object. + /// + /// NOTE: `np.round()` has different behaviors than `round()` when in comes to "tie" cases and return type. + #[must_use] + pub fn call_np_round( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + ) -> AnyObject<'ctx> { + if self.is_float(ctx) { + let n = self.into_float(ctx); + let n = llvm_intrinsics::call_float_rint(ctx, n.value, None); + AnyObject { ty: ctx.primitives.float, value: n.as_basic_value_enum() } + } else if self.is_ndarray(ctx) { + let ndarray = self.into_ndarray(generator, ctx); + ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ndarray.dtype }, + |generator, ctx, scalar| Ok(scalar.call_np_round(generator, ctx)), + ) + .unwrap() + .to_any_object(ctx) + } else { + unsupported_type(ctx, [self.ty]) + } + } + + /// Call `min()` or `max()` on two objects. + pub fn call_min_or_max( + ctx: &mut CodeGenContext<'ctx, '_>, + kind: MinOrMax, + a: AnyObject<'ctx>, + b: AnyObject<'ctx>, + ) -> AnyObject<'ctx> { + if !ctx.unifier.unioned(a.ty, b.ty) { + unsupported_type(ctx, [a.ty, b.ty]) + } + + let common_ty = a.ty; + + if a.is_float(ctx) { + let function = match kind { + MinOrMax::Min => llvm_intrinsics::call_float_minnum, + MinOrMax::Max => llvm_intrinsics::call_float_maxnum, + }; + + let a = a.into_float(ctx).value; + let b = b.into_float(ctx).value; + let result = function(ctx, a, b, None); + AnyObject { value: result.as_basic_value_enum(), ty: ctx.primitives.float } + } else if a.is_unsigned_int(ctx) || a.is_bool(ctx) { + // Treating bool has an unsigned int since that is convenient + let function = match kind { + MinOrMax::Min => llvm_intrinsics::call_int_umin, + MinOrMax::Max => llvm_intrinsics::call_int_umax, + }; + + let a = a.into_int(ctx); + let b = b.into_int(ctx); + let result = function(ctx, a, b, None); + AnyObject { value: result.as_basic_value_enum(), ty: common_ty } + } else if a.is_signed_int(ctx) { + let function = match kind { + MinOrMax::Min => llvm_intrinsics::call_int_smin, + MinOrMax::Max => llvm_intrinsics::call_int_smax, + }; + + let a = a.into_int(ctx); + let b = b.into_int(ctx); + let result = function(ctx, a, b, None); + AnyObject { value: result.as_basic_value_enum(), ty: common_ty } + } else { + unsupported_type(ctx, [common_ty]) + } + } + + /// Call `floor()` or `ceil()` on this object. + /// + /// It is possible to specify which kind of int type to return. + #[must_use] + pub fn call_floor_or_ceil( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + kind: FloorOrCeil, + ret_int_ty: Type, + ) -> Self { + let ret_int_ty_llvm = ctx.get_llvm_type(generator, ret_int_ty).into_int_type(); + + if self.is_float(ctx) { + let function = match kind { + FloorOrCeil::Floor => llvm_intrinsics::call_float_floor, + FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil, + }; + + let n = self.into_float(ctx).value; + let n = function(ctx, n, None); + let n = ctx.builder.build_float_to_signed_int(n, ret_int_ty_llvm, "").unwrap(); + AnyObject { ty: ret_int_ty, value: n.as_basic_value_enum() } + } else { + unsupported_type(ctx, [self.ty]) + } + } + + /// Call `np_floor()` or `np_ceil()` on this object. + #[must_use] + pub fn call_np_floor_or_ceil( + &self, + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + kind: FloorOrCeil, + ) -> Self { + // TODO: + if self.is_float(ctx) { + let function = match kind { + FloorOrCeil::Floor => llvm_intrinsics::call_float_floor, + FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil, + }; + let n = self.into_float(ctx).value; + let n = function(ctx, n, None); + AnyObject { ty: ctx.primitives.float, value: n.as_basic_value_enum() } + } else if self.is_ndarray(ctx) { + let ndarray = self.into_ndarray(generator, ctx); + ndarray + .map( + generator, + ctx, + NDArrayOut::NewNDArray { dtype: ctx.primitives.float }, + |generator, ctx, scalar| Ok(scalar.call_np_floor_or_ceil(generator, ctx, kind)), + ) + .unwrap() + .to_any_object(ctx) + } else { + unsupported_type(ctx, [self.ty]) } } } diff --git a/nac3core/src/codegen/object/ndarray/factory.rs b/nac3core/src/codegen/object/ndarray/factory.rs index c5d0ab4a..68294340 100644 --- a/nac3core/src/codegen/object/ndarray/factory.rs +++ b/nac3core/src/codegen/object/ndarray/factory.rs @@ -1,10 +1,10 @@ use inkwell::{values::BasicValueEnum, IntPredicate}; -use super::{scalar::ScalarObject, NDArrayObject}; +use super::NDArrayObject; use crate::{ codegen::{ - irrt::call_nac3_ndarray_util_assert_shape_no_negative, model::*, CodeGenContext, - CodeGenerator, + irrt::call_nac3_ndarray_util_assert_shape_no_negative, model::*, object::AnyObject, + CodeGenContext, CodeGenerator, }, typecheck::typedef::Type, }; @@ -93,10 +93,10 @@ impl<'ctx> NDArrayObject<'ctx> { dtype: Type, ndims: u64, shape: Ptr<'ctx, IntModel>, - fill_value: ScalarObject<'ctx>, + fill_value: AnyObject<'ctx>, ) -> Self { // Sanity check on `fill_value`'s dtype. - assert!(ctx.unifier.unioned(dtype, fill_value.dtype)); + assert!(ctx.unifier.unioned(dtype, fill_value.ty)); let ndarray = NDArrayObject::from_np_empty(generator, ctx, dtype, ndims, shape); ndarray.fill(generator, ctx, fill_value); @@ -112,7 +112,7 @@ impl<'ctx> NDArrayObject<'ctx> { shape: Ptr<'ctx, IntModel>, ) -> Self { let fill_value = ndarray_zero_value(generator, ctx, dtype); - let fill_value = ScalarObject { value: fill_value, dtype }; + let fill_value = AnyObject { value: fill_value, ty: dtype }; NDArrayObject::from_np_full(generator, ctx, dtype, ndims, shape, fill_value) } @@ -125,7 +125,7 @@ impl<'ctx> NDArrayObject<'ctx> { shape: Ptr<'ctx, IntModel>, ) -> Self { let fill_value = ndarray_one_value(generator, ctx, dtype); - let fill_value = ScalarObject { value: fill_value, dtype }; + let fill_value = AnyObject { value: fill_value, ty: dtype }; NDArrayObject::from_np_full(generator, ctx, dtype, ndims, shape, fill_value) } diff --git a/nac3core/src/codegen/object/ndarray/functions.rs b/nac3core/src/codegen/object/ndarray/functions.rs index 4ba9a0e0..0ecb97c5 100644 --- a/nac3core/src/codegen/object/ndarray/functions.rs +++ b/nac3core/src/codegen/object/ndarray/functions.rs @@ -1,443 +1,13 @@ -use inkwell::{ - values::{BasicValue, FloatValue, IntValue}, - FloatPredicate, IntPredicate, -}; -use itertools::Itertools; +use inkwell::{FloatPredicate, IntPredicate}; -use crate::{ - codegen::{llvm_intrinsics, model::*, stmt::gen_if_callback, CodeGenContext, CodeGenerator}, - typecheck::typedef::Type, +use crate::codegen::{ + model::*, + object::{AnyObject, MinOrMax}, + stmt::gen_if_callback, + CodeGenContext, CodeGenerator, }; -use super::{scalar::ScalarObject, NDArrayObject}; - -/// Convenience function to crash the program when types of arguments are not supported. -/// Used to be debugged with a stacktrace. -fn unsupported_type(ctx: &CodeGenContext<'_, '_>, tys: I) -> ! -where - I: IntoIterator, -{ - unreachable!( - "unsupported types found '{}'", - tys.into_iter().map(|ty| format!("'{}'", ctx.unifier.stringify(ty))).join(", "), - ) -} - -#[derive(Debug, Clone, Copy)] -pub enum FloorOrCeil { - Floor, - Ceil, -} - -#[derive(Debug, Clone, Copy)] -pub enum MinOrMax { - Min, - Max, -} - -fn signed_ints(ctx: &CodeGenContext<'_, '_>) -> Vec { - vec![ctx.primitives.int32, ctx.primitives.int64] -} - -fn unsigned_ints(ctx: &CodeGenContext<'_, '_>) -> Vec { - vec![ctx.primitives.uint32, ctx.primitives.uint64] -} - -fn ints(ctx: &CodeGenContext<'_, '_>) -> Vec { - vec![ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.uint32, ctx.primitives.uint64] -} - -fn int_like(ctx: &CodeGenContext<'_, '_>) -> Vec { - vec![ - ctx.primitives.bool, - ctx.primitives.int32, - ctx.primitives.int64, - ctx.primitives.uint32, - ctx.primitives.uint64, - ] -} - -fn cast_to_int_conversion<'ctx, 'a, G, HandleFloatFn>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, 'a>, - scalar: ScalarObject<'ctx>, - ret_int_dtype: Type, - handle_float: HandleFloatFn, -) -> ScalarObject<'ctx> -where - G: CodeGenerator + ?Sized, - HandleFloatFn: - FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, FloatValue<'ctx>) -> IntValue<'ctx>, -{ - let ret_int_dtype_llvm = ctx.get_llvm_type(generator, ret_int_dtype).into_int_type(); - - let result = if ctx.unifier.unioned(scalar.dtype, ctx.primitives.float) { - // Special handling for floats - let n = scalar.value.into_float_value(); - handle_float(generator, ctx, n) - } else if ctx.unifier.unioned_any(scalar.dtype, int_like(ctx)) { - let n = scalar.value.into_int_value(); - - if n.get_type().get_bit_width() <= ret_int_dtype_llvm.get_bit_width() { - ctx.builder.build_int_z_extend(n, ret_int_dtype_llvm, "zext").unwrap() - } else { - ctx.builder.build_int_truncate(n, ret_int_dtype_llvm, "trunc").unwrap() - } - } else { - unsupported_type(ctx, [scalar.dtype]); - }; - - assert_eq!(ret_int_dtype_llvm.get_bit_width(), result.get_type().get_bit_width()); // Sanity check - ScalarObject { value: result.into(), dtype: ret_int_dtype } -} - -impl<'ctx> ScalarObject<'ctx> { - /// Convenience function. Assume this scalar has typechecker type float64, get its underlying LLVM value. - /// - /// Panic if the type is wrong. - pub fn into_float64(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> FloatValue<'ctx> { - if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { - self.value.into_float_value() // self.value must be a FloatValue - } else { - panic!("not a float type") - } - } - - /// Convenience function. Assume this scalar has typechecker type int32, get its underlying LLVM value. - /// - /// Panic if the type is wrong. - pub fn into_int32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { - if ctx.unifier.unioned(self.dtype, ctx.primitives.int32) { - let value = self.value.into_int_value(); - debug_assert_eq!(value.get_type().get_bit_width(), 32); // Sanity check - value - } else { - panic!("not a float type") - } - } - - /// Compare two scalars. Only int-to-int and float-to-float comparisons are allowed. - /// Panic otherwise. - pub fn compare( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - lhs: ScalarObject<'ctx>, - rhs: ScalarObject<'ctx>, - int_predicate: IntPredicate, - float_predicate: FloatPredicate, - name: &str, - ) -> Int<'ctx, Bool> { - if !ctx.unifier.unioned(lhs.dtype, rhs.dtype) { - unsupported_type(ctx, [lhs.dtype, rhs.dtype]); - } - - let bool_model = IntModel(Bool); - - let common_ty = lhs.dtype; - let result = if ctx.unifier.unioned(common_ty, ctx.primitives.float) { - let lhs = lhs.value.into_float_value(); - let rhs = rhs.value.into_float_value(); - ctx.builder.build_float_compare(float_predicate, lhs, rhs, name).unwrap() - } else if ctx.unifier.unioned_any(common_ty, int_like(ctx)) { - let lhs = lhs.value.into_int_value(); - let rhs = rhs.value.into_int_value(); - ctx.builder.build_int_compare(int_predicate, lhs, rhs, name).unwrap() - } else { - unsupported_type(ctx, [lhs.dtype, rhs.dtype]); - }; - - bool_model.check_value(generator, ctx.ctx, result).unwrap() - } - - /// Invoke NAC3's builtin `int32()`. - #[must_use] - pub fn cast_to_int32( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - ) -> Self { - cast_to_int_conversion( - generator, - ctx, - *self, - ctx.primitives.int32, - |_generator, ctx, input| { - let n = - ctx.builder.build_float_to_signed_int(input, ctx.ctx.i64_type(), "").unwrap(); - ctx.builder.build_int_truncate(n, ctx.ctx.i32_type(), "conv").unwrap() - }, - ) - } - - /// Invoke NAC3's builtin `int64()`. - #[must_use] - pub fn cast_to_int64( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - ) -> Self { - cast_to_int_conversion( - generator, - ctx, - *self, - ctx.primitives.int64, - |_generator, ctx, input| { - ctx.builder.build_float_to_signed_int(input, ctx.ctx.i64_type(), "").unwrap() - }, - ) - } - - /// Invoke NAC3's builtin `uint32()`. - #[must_use] - pub fn cast_to_uint32( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - ) -> Self { - cast_to_int_conversion( - generator, - ctx, - *self, - ctx.primitives.uint32, - |_generator, ctx, n| { - let n_gez = ctx - .builder - .build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "") - .unwrap(); - - let to_int32 = - ctx.builder.build_float_to_signed_int(n, ctx.ctx.i32_type(), "").unwrap(); - let to_uint64 = - ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap(); - - ctx.builder - .build_select( - n_gez, - ctx.builder.build_int_truncate(to_uint64, ctx.ctx.i32_type(), "").unwrap(), - to_int32, - "conv", - ) - .unwrap() - .into_int_value() - }, - ) - } - - /// Invoke NAC3's builtin `uint64()`. - #[must_use] - pub fn cast_to_uint64( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - ) -> Self { - cast_to_int_conversion( - generator, - ctx, - *self, - ctx.primitives.uint64, - |_generator, ctx, n| { - let val_gez = ctx - .builder - .build_float_compare(FloatPredicate::OGE, n, n.get_type().const_zero(), "") - .unwrap(); - - let to_int64 = - ctx.builder.build_float_to_signed_int(n, ctx.ctx.i64_type(), "").unwrap(); - let to_uint64 = - ctx.builder.build_float_to_unsigned_int(n, ctx.ctx.i64_type(), "").unwrap(); - - ctx.builder - .build_select(val_gez, to_uint64, to_int64, "conv") - .unwrap() - .into_int_value() - }, - ) - } - - /// Invoke NAC3's builtin `bool()`. - #[must_use] - pub fn cast_to_bool(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self { - // TODO: Why is the original code being so lax about i1 and i8 for the returned int type? - let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.bool) { - self.value.into_int_value() - } else if ctx.unifier.unioned_any(self.dtype, ints(ctx)) { - let n = self.value.into_int_value(); - ctx.builder - .build_int_compare(inkwell::IntPredicate::NE, n, n.get_type().const_zero(), "bool") - .unwrap() - } else if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { - let n = self.value.into_float_value(); - ctx.builder - .build_float_compare(FloatPredicate::UNE, n, n.get_type().const_zero(), "bool") - .unwrap() - } else { - unsupported_type(ctx, [self.dtype]) - }; - - ScalarObject { dtype: ctx.primitives.bool, value: result.as_basic_value_enum() } - } - - /// Invoke NAC3's builtin `float()`. - #[must_use] - pub fn cast_to_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self { - let llvm_f64 = ctx.ctx.f64_type(); - - let result: FloatValue<'_> = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { - self.value.into_float_value() - } else if ctx - .unifier - .unioned_any(self.dtype, [signed_ints(ctx).as_slice(), &[ctx.primitives.bool]].concat()) - { - let n = self.value.into_int_value(); - ctx.builder.build_signed_int_to_float(n, llvm_f64, "sitofp").unwrap() - } else if ctx.unifier.unioned_any(self.dtype, unsigned_ints(ctx)) { - let n = self.value.into_int_value(); - ctx.builder.build_unsigned_int_to_float(n, llvm_f64, "uitofp").unwrap() - } else { - unsupported_type(ctx, [self.dtype]); - }; - - ScalarObject { value: result.as_basic_value_enum(), dtype: ctx.primitives.float } - } - - /// Invoke NAC3's builtin `round()`. - #[must_use] - pub fn round( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - ret_int_dtype: Type, - ) -> Self { - let ret_int_dtype_llvm = ctx.get_llvm_type(generator, ret_int_dtype).into_int_type(); - - let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { - let n = self.value.into_float_value(); - let n = llvm_intrinsics::call_float_round(ctx, n, None); - ctx.builder.build_float_to_signed_int(n, ret_int_dtype_llvm, "round").unwrap() - } else { - unsupported_type(ctx, [self.dtype, ret_int_dtype]) - }; - ScalarObject { dtype: ret_int_dtype, value: result.as_basic_value_enum() } - } - - /// Invoke NAC3's builtin `np_round()`. - /// - /// NOTE: `np.round()` has different behaviors than `round()` when in comes to "tie" cases and return type. - #[must_use] - pub fn np_round(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self { - let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { - let n = self.value.into_float_value(); - llvm_intrinsics::call_float_rint(ctx, n, None) - } else { - unsupported_type(ctx, [self.dtype]) - }; - ScalarObject { dtype: ctx.primitives.float, value: result.as_basic_value_enum() } - } - - /// Invoke NAC3's builtin `min()` or `max()`. - pub fn min_or_max( - ctx: &mut CodeGenContext<'ctx, '_>, - kind: MinOrMax, - a: Self, - b: Self, - ) -> Self { - if !ctx.unifier.unioned(a.dtype, b.dtype) { - unsupported_type(ctx, [a.dtype, b.dtype]) - } - - let common_dtype = a.dtype; - - if ctx.unifier.unioned(common_dtype, ctx.primitives.float) { - let function = match kind { - MinOrMax::Min => llvm_intrinsics::call_float_minnum, - MinOrMax::Max => llvm_intrinsics::call_float_maxnum, - }; - let result = - function(ctx, a.value.into_float_value(), b.value.into_float_value(), None); - ScalarObject { value: result.as_basic_value_enum(), dtype: ctx.primitives.float } - } else if ctx.unifier.unioned_any( - common_dtype, - [unsigned_ints(ctx).as_slice(), &[ctx.primitives.bool]].concat(), - ) { - // Treating bool has an unsigned int since that is convenient - let function = match kind { - MinOrMax::Min => llvm_intrinsics::call_int_umin, - MinOrMax::Max => llvm_intrinsics::call_int_umax, - }; - let result = function(ctx, a.value.into_int_value(), b.value.into_int_value(), None); - ScalarObject { value: result.as_basic_value_enum(), dtype: common_dtype } - } else { - unsupported_type(ctx, [common_dtype]) - } - } - - /// Invoke NAC3's builtin `floor()` or `ceil()`. - /// - /// * `ret_int_dtype` - The type of int to return. - /// - /// Takes in a float/int and returns an int of type `ret_int_dtype` - #[must_use] - pub fn floor_or_ceil( - &self, - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - kind: FloorOrCeil, - ret_int_dtype: Type, - ) -> Self { - let ret_int_dtype_llvm = ctx.get_llvm_type(generator, ret_int_dtype).into_int_type(); - - if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { - let function = match kind { - FloorOrCeil::Floor => llvm_intrinsics::call_float_floor, - FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil, - }; - let n = self.value.into_float_value(); - let n = function(ctx, n, None); - - let n = ctx.builder.build_float_to_signed_int(n, ret_int_dtype_llvm, "").unwrap(); - ScalarObject { dtype: ret_int_dtype, value: n.as_basic_value_enum() } - } else { - unsupported_type(ctx, [self.dtype]) - } - } - - /// Invoke NAC3's builtin `np_floor()`/ `np_ceil()`. - /// - /// Takes in a float/int and returns a float64 result. - #[must_use] - pub fn np_floor_or_ceil(&self, ctx: &mut CodeGenContext<'ctx, '_>, kind: FloorOrCeil) -> Self { - if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { - let function = match kind { - FloorOrCeil::Floor => llvm_intrinsics::call_float_floor, - FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil, - }; - let n = self.value.into_float_value(); - let n = function(ctx, n, None); - ScalarObject { dtype: ctx.primitives.float, value: n.as_basic_value_enum() } - } else { - unsupported_type(ctx, [self.dtype]) - } - } - - /// Invoke NAC3's builtin `abs()`. - #[must_use] - pub fn abs(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self { - if ctx.unifier.unioned(self.dtype, ctx.primitives.float) { - let n = self.value.into_float_value(); - let n = llvm_intrinsics::call_float_fabs(ctx, n, Some("abs")); - ScalarObject { value: n.into(), dtype: ctx.primitives.float } - } else if ctx.unifier.unioned_any(self.dtype, ints(ctx)) { - let n = self.value.into_int_value(); - - let is_poisoned = ctx.ctx.bool_type().const_zero(); // is_poisoned = false - let n = llvm_intrinsics::call_int_abs(ctx, n, is_poisoned, Some("abs")); - - ScalarObject { value: n.into(), dtype: self.dtype } - } else { - unsupported_type(ctx, [self.dtype]) - } - } -} +use super::NDArrayObject; impl<'ctx> NDArrayObject<'ctx> { /// Helper function to implement NAC3's builtin `np_min()`, `np_max()`, `np_argmin()`, and `np_argmax()`. @@ -451,7 +21,7 @@ impl<'ctx> NDArrayObject<'ctx> { ctx: &mut CodeGenContext<'ctx, '_>, kind: MinOrMax, on_empty_err_msg: &str, - ) -> (ScalarObject<'ctx>, Int<'ctx, SizeT>) { + ) -> (AnyObject<'ctx>, Int<'ctx, SizeT>) { let sizet_model = IntModel(SizeT); let dtype_llvm = ctx.get_llvm_type(generator, self.dtype); @@ -478,17 +48,17 @@ impl<'ctx> NDArrayObject<'ctx> { self.foreach(generator, ctx, |generator, ctx, _hooks, nditer| { let old_extremum = ctx.builder.build_load(pextremum, "current_extremum").unwrap(); - let old_extremum = ScalarObject { dtype: self.dtype, value: old_extremum }; + let old_extremum = AnyObject { ty: self.dtype, value: old_extremum }; let scalar = nditer.get_scalar(generator, ctx); - let new_extremum = ScalarObject::min_or_max(ctx, kind, old_extremum, scalar); + let new_extremum = AnyObject::call_min_or_max(ctx, kind, old_extremum, scalar); gen_if_callback( generator, ctx, |generator, ctx| { // Is new_extremum is more extreme than old_extremum? - let cmp = ScalarObject::compare( + let cmp = AnyObject::compare_int_or_float_by_predicate( generator, ctx, new_extremum, @@ -517,7 +87,7 @@ impl<'ctx> NDArrayObject<'ctx> { let extremum_index = pextremum_index.load(generator, ctx, "extremum_index"); let extremum = ctx.builder.build_load(pextremum, "extremum_value").unwrap(); - let extremum = ScalarObject { dtype: self.dtype, value: extremum }; + let extremum = AnyObject { ty: self.dtype, value: extremum }; (extremum, extremum_index) } @@ -528,7 +98,7 @@ impl<'ctx> NDArrayObject<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, kind: MinOrMax, - ) -> ScalarObject<'ctx> { + ) -> AnyObject<'ctx> { let on_empty_err_msg = format!( "zero-size array to reduction operation {} which has no identity", match kind { diff --git a/nac3core/src/codegen/object/ndarray/mapping.rs b/nac3core/src/codegen/object/ndarray/mapping.rs index 6d30019d..85720590 100644 --- a/nac3core/src/codegen/object/ndarray/mapping.rs +++ b/nac3core/src/codegen/object/ndarray/mapping.rs @@ -1,9 +1,8 @@ -use inkwell::values::BasicValueEnum; use itertools::Itertools; use crate::{ codegen::{ - object::ndarray::{NDArrayObject, ScalarObject}, + object::ndarray::{AnyObject, NDArrayObject}, stmt::gen_for_callback, CodeGenContext, CodeGenerator, }, @@ -27,8 +26,8 @@ impl<'ctx> NDArrayObject<'ctx> { MappingFn: FnOnce( &mut G, &mut CodeGenContext<'ctx, 'a>, - &[ScalarObject<'ctx>], - ) -> Result, String>, + &[AnyObject<'ctx>], + ) -> Result, String>, { // Broadcast inputs let broadcast_result = NDArrayObject::broadcast(generator, ctx, ndarrays); @@ -89,9 +88,11 @@ impl<'ctx> NDArrayObject<'ctx> { in_nditers.iter().map(|nditer| nditer.get_scalar(generator, ctx)).collect_vec(); let result = mapping(generator, ctx, &in_scalars)?; + // Sanity check on result's ty + assert!(ctx.unifier.unioned(result.ty, out_ndarray.dtype)); let p = out_nditer.get_pointer(generator, ctx); - ctx.builder.build_store(p, result).unwrap(); + ctx.builder.build_store(p, result.value).unwrap(); Ok(()) }, @@ -103,20 +104,6 @@ impl<'ctx> NDArrayObject<'ctx> { }, )?; - // let start = sizet_model.const_0(generator, ctx.ctx); - // let stop = broadcast_result.ndarrays[0].size(generator, ctx); // They all should have the same `np.size`. - // let step = sizet_model.const_1(generator, ctx.ctx); - // gen_for_model_auto(generator, ctx, start, stop, step, move |generator, ctx, _hooks, i| { - // let elements = - // ndarrays.iter().map(|ndarray| ndarray.get_nth_scalar(generator, ctx, i)).collect_vec(); - - // let ret = mapping(generator, ctx, i, &elements)?; - - // let pret = out_ndarray.get_nth_pointer(generator, ctx, i, "pret"); - // ctx.builder.build_store(pret, ret).unwrap(); - // Ok(()) - // })?; - Ok(out_ndarray) } @@ -132,8 +119,8 @@ impl<'ctx> NDArrayObject<'ctx> { Mapping: FnOnce( &mut G, &mut CodeGenContext<'ctx, 'a>, - ScalarObject<'ctx>, - ) -> Result, String>, + AnyObject<'ctx>, + ) -> Result, String>, { NDArrayObject::broadcasting_starmap( generator, @@ -160,16 +147,18 @@ impl<'ctx> ScalarOrNDArray<'ctx> { MappingFn: FnOnce( &mut G, &mut CodeGenContext<'ctx, 'a>, - &[ScalarObject<'ctx>], - ) -> Result, String>, + &[AnyObject<'ctx>], + ) -> Result, String>, { - // Check if all inputs are ScalarObjects - let all_scalars: Option> = - inputs.iter().map(ScalarObject::try_from).try_collect().ok(); + // Check if all inputs are AnyObjects + let all_scalars: Option> = inputs.iter().map(AnyObject::try_from).try_collect().ok(); if let Some(scalars) = all_scalars { - let scalar = - ScalarObject { value: mapping(generator, ctx, &scalars)?, dtype: ret_dtype }; + let scalar = mapping(generator, ctx, &scalars)?; + + // Sanity check on scalar's type + assert!(ctx.unifier.unioned(scalar.ty, ret_dtype)); + Ok(ScalarOrNDArray::Scalar(scalar)) } else { // Promote all input to ndarrays and map through them. @@ -197,8 +186,8 @@ impl<'ctx> ScalarOrNDArray<'ctx> { Mapping: FnOnce( &mut G, &mut CodeGenContext<'ctx, 'a>, - ScalarObject<'ctx>, - ) -> Result, String>, + AnyObject<'ctx>, + ) -> Result, String>, { ScalarOrNDArray::broadcasting_starmap( generator, diff --git a/nac3core/src/codegen/object/ndarray/mod.rs b/nac3core/src/codegen/object/ndarray/mod.rs index d6a7acab..397c87f7 100644 --- a/nac3core/src/codegen/object/ndarray/mod.rs +++ b/nac3core/src/codegen/object/ndarray/mod.rs @@ -38,7 +38,7 @@ use inkwell::{ AddressSpace, IntPredicate, }; use nditer::NDIterHandle; -use scalar::{ScalarObject, ScalarOrNDArray}; +use scalar::ScalarOrNDArray; use util::call_memcpy_model; use super::{tuple::TupleObject, AnyObject}; @@ -257,10 +257,10 @@ impl<'ctx> NDArrayObject<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, nth: Int<'ctx, SizeT>, - ) -> ScalarObject<'ctx> { + ) -> AnyObject<'ctx> { let p = self.get_nth_pointer(generator, ctx, nth, "value"); let value = ctx.builder.build_load(p, "value").unwrap(); - ScalarObject { dtype: self.dtype, value } + AnyObject { ty: self.dtype, value } } /// Set the n-th (0-based) scalar. @@ -271,10 +271,10 @@ impl<'ctx> NDArrayObject<'ctx> { generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, nth: Int<'ctx, SizeT>, - scalar: ScalarObject<'ctx>, + scalar: AnyObject<'ctx>, ) { // Sanity check on scalar's `dtype` - assert!(ctx.unifier.unioned(scalar.dtype, self.dtype)); + assert!(ctx.unifier.unioned(scalar.ty, self.dtype)); let pscalar = self.get_nth_pointer(generator, ctx, nth, "pscalar"); ctx.builder.build_store(pscalar, scalar.value).unwrap(); @@ -284,7 +284,7 @@ impl<'ctx> NDArrayObject<'ctx> { /// /// Please refer to the IRRT implementation to see its purpose. pub fn update_strides_by_shape( - &self, + self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, ) { @@ -458,7 +458,7 @@ impl<'ctx> NDArrayObject<'ctx> { self.ndims == 0 } - /// If this ndarray is unsized, return its sole value as a [`ScalarObject`]. Otherwise, do nothing. + /// If this ndarray is unsized, return its sole value as a [`AnyObject`]. Otherwise, do nothing. pub fn split_unsized( &self, generator: &mut G, @@ -604,10 +604,10 @@ impl<'ctx> NDArrayObject<'ctx> { &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - scalar: ScalarObject<'ctx>, + scalar: AnyObject<'ctx>, ) { // Sanity check on scalar's type. - assert!(ctx.unifier.unioned(self.dtype, scalar.dtype)); + assert!(ctx.unifier.unioned(self.dtype, scalar.ty)); self.foreach(generator, ctx, |generator, ctx, _hooks, nditer| { let p = nditer.get_pointer(generator, ctx); diff --git a/nac3core/src/codegen/object/ndarray/nditer.rs b/nac3core/src/codegen/object/ndarray/nditer.rs index f9180cfd..05710003 100644 --- a/nac3core/src/codegen/object/ndarray/nditer.rs +++ b/nac3core/src/codegen/object/ndarray/nditer.rs @@ -3,7 +3,7 @@ use inkwell::{types::BasicType, values::PointerValue, AddressSpace}; use crate::codegen::{ irrt::{call_nac3_nditer_has_next, call_nac3_nditer_initialize, call_nac3_nditer_next}, model::*, - object::ndarray::scalar::ScalarObject, + object::AnyObject, structure::NDIter, CodeGenContext, CodeGenerator, }; @@ -63,10 +63,10 @@ impl<'ctx> NDIterHandle<'ctx> { &self, generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - ) -> ScalarObject<'ctx> { + ) -> AnyObject<'ctx> { let p = self.get_pointer(generator, ctx); let value = ctx.builder.build_load(p, "value").unwrap(); - ScalarObject { dtype: self.ndarray.dtype, value } + AnyObject { ty: self.ndarray.dtype, value } } pub fn get_index( diff --git a/nac3core/src/codegen/object/ndarray/scalar.rs b/nac3core/src/codegen/object/ndarray/scalar.rs index ceeaa273..c7598087 100644 --- a/nac3core/src/codegen/object/ndarray/scalar.rs +++ b/nac3core/src/codegen/object/ndarray/scalar.rs @@ -7,18 +7,7 @@ use crate::{ use super::NDArrayObject; -/// An LLVM numpy scalar with its [`Type`]. -/// -/// Intended to be used with [`ScalarOrNDArray`]. -/// -/// A scalar does not have to be an actual number. It could be arbitrary objects. -#[derive(Debug, Clone, Copy)] -pub struct ScalarObject<'ctx> { - pub dtype: Type, - pub value: BasicValueEnum<'ctx>, -} - -impl<'ctx> ScalarObject<'ctx> { +impl<'ctx> AnyObject<'ctx> { /// Promote this scalar to an unsized ndarray (like doing `np.asarray`). /// /// The scalar value is allocated onto the stack, and the ndarray's `data` will point to that @@ -35,7 +24,7 @@ impl<'ctx> ScalarObject<'ctx> { ctx.builder.build_store(data, self.value).unwrap(); let data = pbyte_model.pointer_cast(generator, ctx, data, "data"); - let ndarray = NDArrayObject::alloca(generator, ctx, self.dtype, 0, "scalar_ndarray"); + let ndarray = NDArrayObject::alloca(generator, ctx, self.ty, 0, "scalar_ndarray"); ndarray.instance.set(ctx, |f| f.data, data); ndarray } @@ -44,7 +33,7 @@ impl<'ctx> ScalarObject<'ctx> { /// A convenience enum for implementing scalar/ndarray agnostic utilities. #[derive(Debug, Clone, Copy)] pub enum ScalarOrNDArray<'ctx> { - Scalar(ScalarObject<'ctx>), + Scalar(AnyObject<'ctx>), NDArray(NDArrayObject<'ctx>), } @@ -59,7 +48,7 @@ impl<'ctx> ScalarOrNDArray<'ctx> { } #[must_use] - pub fn into_scalar(&self) -> ScalarObject<'ctx> { + pub fn into_scalar(&self) -> AnyObject<'ctx> { match self { ScalarOrNDArray::NDArray(_ndarray) => panic!("Got NDArray"), ScalarOrNDArray::Scalar(scalar) => *scalar, @@ -91,13 +80,13 @@ impl<'ctx> ScalarOrNDArray<'ctx> { #[must_use] pub fn dtype(&self) -> Type { match self { - ScalarOrNDArray::Scalar(scalar) => scalar.dtype, + ScalarOrNDArray::Scalar(scalar) => scalar.ty, ScalarOrNDArray::NDArray(ndarray) => ndarray.dtype, } } } -impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for ScalarObject<'ctx> { +impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for AnyObject<'ctx> { type Error = (); fn try_from(value: &ScalarOrNDArray<'ctx>) -> Result { @@ -135,7 +124,7 @@ pub fn split_scalar_or_ndarray<'ctx, G: CodeGenerator + ?Sized>( ScalarOrNDArray::NDArray(ndarray) } _ => { - let scalar = ScalarObject { dtype: object.ty, value: object.value }; + let scalar = AnyObject { ty: object.ty, value: object.value }; ScalarOrNDArray::Scalar(scalar) } } diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 073a8ff1..b18df8d2 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -1143,12 +1143,12 @@ impl<'a> BuiltinBuilder<'a> { ret_dtype, |generator, ctx, scalar| { let result = match prim { - PrimDef::FunInt32 => scalar.cast_to_int32(generator, ctx), - PrimDef::FunInt64 => scalar.cast_to_int64(generator, ctx), - PrimDef::FunUInt32 => scalar.cast_to_uint32(generator, ctx), - PrimDef::FunUInt64 => scalar.cast_to_uint64(generator, ctx), - PrimDef::FunFloat => scalar.cast_to_float(ctx), - PrimDef::FunBool => scalar.cast_to_bool(ctx), + PrimDef::FunInt32 => scalar.call_int32(generator, ctx), + PrimDef::FunInt64 => scalar.call_int64(generator, ctx), + PrimDef::FunUInt32 => scalar.call_uint32(generator, ctx), + PrimDef::FunUInt64 => scalar.call_uint64(generator, ctx), + PrimDef::FunFloat => scalar.call_float(ctx), + PrimDef::FunBool => scalar.call_bool(ctx), _ => unreachable!(), }; Ok(result.value) @@ -1277,7 +1277,7 @@ impl<'a> BuiltinBuilder<'a> { ctx, int_sized, |generator, ctx, scalar| { - Ok(scalar.floor_or_ceil(generator, ctx, kind, int_sized).value) + Ok(scalar.call_floor_or_ceil(generator, ctx, kind, int_sized).value) }, )?; Ok(Some(result.to_basic_value_enum())) @@ -1927,7 +1927,7 @@ impl<'a> BuiltinBuilder<'a> { let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; let arg = AnyObject { value: arg, ty: arg_ty }; - Ok(Some(arg.len(generator, ctx).value.as_basic_value_enum())) + Ok(Some(arg.call_len(generator, ctx).value.as_basic_value_enum())) }, )))), loc: None,