forked from M-Labs/nac3
WIP: core/ndstrides: checkpoint 13
This commit is contained in:
parent
4b765cfb27
commit
a69a441bdd
|
@ -1555,9 +1555,8 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
if op.base == Operator::MatMult {
|
if op.base == Operator::MatMult {
|
||||||
// Handle `left @ right`
|
// Handle `left @ right`
|
||||||
let result = NDArrayObject::matmul(generator, ctx, left, right, out)
|
let result = NDArrayObject::matmul(generator, ctx, left, right, out)
|
||||||
.split_unsized(generator, ctx)
|
.split_unsized(generator, ctx);
|
||||||
.to_basic_value_enum();
|
Ok(Some(ValueEnum::Dynamic(result.to_basic_value_enum())))
|
||||||
Ok(Some(ValueEnum::Dynamic(result)))
|
|
||||||
} else {
|
} else {
|
||||||
// For other operators like +, -, etc...; do them element-wise-ly
|
// For other operators like +, -, etc...; do them element-wise-ly
|
||||||
let result = NDArrayObject::broadcasting_starmap(
|
let result = NDArrayObject::broadcasting_starmap(
|
||||||
|
@ -1566,18 +1565,21 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
&[left, right],
|
&[left, right],
|
||||||
out,
|
out,
|
||||||
|generator, ctx, scalars| {
|
|generator, ctx, scalars| {
|
||||||
let left = scalars[0];
|
let left_scalar = scalars[0];
|
||||||
let right = scalars[1];
|
let right_scalar = scalars[1];
|
||||||
gen_binop_expr_with_values(
|
|
||||||
|
let result = gen_binop_expr_with_values(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
(&Some(left.dtype), left.value),
|
(&Some(left_scalar.ty), left_scalar.value),
|
||||||
op,
|
op,
|
||||||
(&Some(right.dtype), right.value),
|
(&Some(right_scalar.ty), right_scalar.value),
|
||||||
ctx.current_loc,
|
ctx.current_loc,
|
||||||
)?
|
)?
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.to_basic_value_enum(ctx, generator, common_dtype)
|
.to_basic_value_enum(ctx, generator, common_dtype)?;
|
||||||
|
|
||||||
|
Ok(AnyObject { value: result, ty: common_dtype })
|
||||||
},
|
},
|
||||||
)?;
|
)?;
|
||||||
Ok(Some(ValueEnum::Dynamic(result.instance.value.as_basic_value_enum())))
|
Ok(Some(ValueEnum::Dynamic(result.instance.value.as_basic_value_enum())))
|
||||||
|
@ -1671,6 +1673,7 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>(
|
||||||
|
|
||||||
/// Generates LLVM IR for a unary operator expression using the [`Type`] and
|
/// Generates LLVM IR for a unary operator expression using the [`Type`] and
|
||||||
/// [LLVM value][`BasicValueEnum`] of the operands.
|
/// [LLVM value][`BasicValueEnum`] of the operands.
|
||||||
|
#[allow(clippy::only_used_in_recursion)]
|
||||||
pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
|
|
@ -63,7 +63,7 @@ impl<'ctx, Element: Model<'ctx>> Ptr<'ctx, ArrayModel<Element>> {
|
||||||
assert!(i < self.model.0.len);
|
assert!(i < self.model.0.len);
|
||||||
|
|
||||||
let zero = ctx.ctx.i32_type().const_zero();
|
let zero = ctx.ctx.i32_type().const_zero();
|
||||||
let i = ctx.ctx.i32_type().const_int(i as u64, false);
|
let i = ctx.ctx.i32_type().const_int(u64::from(i), false);
|
||||||
let ptr = unsafe { ctx.builder.build_in_bounds_gep(self.value, &[zero, i], name).unwrap() };
|
let ptr = unsafe { ctx.builder.build_in_bounds_gep(self.value, &[zero, i], name).unwrap() };
|
||||||
|
|
||||||
PtrModel(self.model.0.element).check_value(generator, ctx.ctx, ptr).unwrap()
|
PtrModel(self.model.0.element).check_value(generator, ctx.ctx, ptr).unwrap()
|
||||||
|
@ -114,7 +114,7 @@ impl<'ctx, const LEN: u32, Element: Model<'ctx>> Ptr<'ctx, NArrayModel<LEN, Elem
|
||||||
assert!(i < LEN);
|
assert!(i < LEN);
|
||||||
|
|
||||||
let zero = ctx.ctx.i32_type().const_zero();
|
let zero = ctx.ctx.i32_type().const_zero();
|
||||||
let i = ctx.ctx.i32_type().const_int(i as u64, false);
|
let i = ctx.ctx.i32_type().const_int(u64::from(i), false);
|
||||||
let ptr = unsafe { ctx.builder.build_in_bounds_gep(self.value, &[zero, i], name).unwrap() };
|
let ptr = unsafe { ctx.builder.build_in_bounds_gep(self.value, &[zero, i], name).unwrap() };
|
||||||
|
|
||||||
PtrModel(self.model.0 .0).check_value(generator, ctx.ctx, ptr).unwrap()
|
PtrModel(self.model.0 .0).check_value(generator, ctx.ctx, ptr).unwrap()
|
||||||
|
|
|
@ -86,10 +86,17 @@ impl<'ctx> AnyObject<'ctx> {
|
||||||
matches!(&*ctx.unifier.get_ty(self.ty), TypeEnum::TTuple { .. })
|
matches!(&*ctx.unifier.get_ty(self.ty), TypeEnum::TTuple { .. })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn into_tuple() {}
|
||||||
|
|
||||||
pub fn is_int32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
|
pub fn is_int32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
|
||||||
ctx.unifier.unioned(self.ty, ctx.primitives.int32)
|
ctx.unifier.unioned(self.ty, ctx.primitives.int32)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn into_int32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Int<'ctx, Int32> {
|
||||||
|
assert!(self.is_int32(ctx));
|
||||||
|
IntModel(Int32).believe_value(self.value.into_int_value())
|
||||||
|
}
|
||||||
|
|
||||||
pub fn is_uint32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
|
pub fn is_uint32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
|
||||||
ctx.unifier.unioned(self.ty, ctx.primitives.uint32)
|
ctx.unifier.unioned(self.ty, ctx.primitives.uint32)
|
||||||
}
|
}
|
||||||
|
@ -106,44 +113,33 @@ impl<'ctx> AnyObject<'ctx> {
|
||||||
ctx.unifier.unioned(self.ty, ctx.primitives.bool)
|
ctx.unifier.unioned(self.ty, ctx.primitives.bool)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns true if the object type is `bool`, `int32`, `int64`, `uint32`, or `uint64`.
|
||||||
pub fn is_int_like(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
|
pub fn is_int_like(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
|
||||||
ctx.unifier.unioned_any(self.ty, int_like(ctx))
|
ctx.unifier.unioned_any(self.ty, int_like(ctx))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns true if the object type is `int32`, `int64`.
|
||||||
pub fn is_signed_int(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
|
pub fn is_signed_int(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
|
||||||
ctx.unifier.unioned_any(self.ty, signed_ints(ctx))
|
ctx.unifier.unioned_any(self.ty, signed_ints(ctx))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns true if the object type is `uint32`, `uint64`.
|
||||||
pub fn is_unsigned_int(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
|
pub fn is_unsigned_int(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
|
||||||
ctx.unifier.unioned_any(self.ty, unsigned_ints(ctx))
|
ctx.unifier.unioned_any(self.ty, unsigned_ints(ctx))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convenience function. If object has type `int32`, `int64`, `uint32`, `uint64`, or `bool`,
|
|
||||||
/// get its underlying LLVM value.
|
|
||||||
///
|
|
||||||
/// Panic if the type is wrong.
|
|
||||||
pub fn into_int(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
|
pub fn into_int(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
|
||||||
if self.is_int_like(ctx) {
|
assert!(self.is_int_like(ctx));
|
||||||
self.value.into_int_value()
|
self.value.into_int_value()
|
||||||
} else {
|
|
||||||
panic!("not an int32 type")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn is_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
|
pub fn is_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
|
||||||
self.is_obj(ctx, PrimDef::Float)
|
self.is_obj(ctx, PrimDef::Float)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convenience function. If object has type `float`, get its underlying LLVM value.
|
|
||||||
///
|
|
||||||
/// Panic if the type is wrong.
|
|
||||||
pub fn into_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Float<'ctx, Float64> {
|
pub fn into_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Float<'ctx, Float64> {
|
||||||
if self.is_float(ctx) {
|
assert!(self.is_float(ctx));
|
||||||
// self.value must be a FloatValue
|
FloatModel(Float64).believe_value(self.value.into_float_value())
|
||||||
FloatModel(Float64).believe_value(self.value.into_float_value())
|
|
||||||
} else {
|
|
||||||
panic!("not a float type")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn is_ndarray(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
|
pub fn is_ndarray(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
|
||||||
|
@ -158,6 +154,15 @@ impl<'ctx> AnyObject<'ctx> {
|
||||||
NDArrayObject::from_object(generator, ctx, *self)
|
NDArrayObject::from_object(generator, ctx, *self)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create an object from a boolean from an i1.
|
||||||
|
///
|
||||||
|
/// NOTE: In NAC3, booleans are i8. This function does converts the input i1 to an i8.
|
||||||
|
pub fn from_bool(ctx: &mut CodeGenContext<'ctx, '_>, n: Int<'ctx, Bool>) -> AnyObject<'ctx> {
|
||||||
|
let llvm_i8 = ctx.ctx.i8_type();
|
||||||
|
let value = ctx.builder.build_int_z_extend(n.value, llvm_i8, "bool").unwrap();
|
||||||
|
AnyObject { value: value.as_basic_value_enum(), ty: ctx.primitives.bool }
|
||||||
|
}
|
||||||
|
|
||||||
/// Helper function to compare two scalars.
|
/// Helper function to compare two scalars.
|
||||||
///
|
///
|
||||||
/// Only int-to-int and float-to-float comparisons are allowed.
|
/// Only int-to-int and float-to-float comparisons are allowed.
|
||||||
|
@ -172,9 +177,7 @@ impl<'ctx> AnyObject<'ctx> {
|
||||||
float_predicate: FloatPredicate,
|
float_predicate: FloatPredicate,
|
||||||
name: &str,
|
name: &str,
|
||||||
) -> Int<'ctx, Bool> {
|
) -> Int<'ctx, Bool> {
|
||||||
if !ctx.unifier.unioned(lhs.ty, rhs.ty) {
|
assert!(ctx.unifier.unioned(lhs.ty, rhs.ty), "lhs and rhs type should be the same");
|
||||||
panic!("lhs and rhs type are not the same.")
|
|
||||||
}
|
|
||||||
|
|
||||||
let bool_model = IntModel(Bool);
|
let bool_model = IntModel(Bool);
|
||||||
|
|
||||||
|
@ -320,6 +323,7 @@ impl<'ctx> AnyObject<'ctx> {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the `len()` of this object.
|
// Get the `len()` of this object.
|
||||||
|
#[must_use]
|
||||||
pub fn call_len<G: CodeGenerator + ?Sized>(
|
pub fn call_len<G: CodeGenerator + ?Sized>(
|
||||||
&self,
|
&self,
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
|
@ -378,20 +382,14 @@ impl<'ctx> AnyObject<'ctx> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Call `bool()` on this object.
|
/// Call `bool()` on this object.
|
||||||
pub fn call_bool<G: CodeGenerator + ?Sized>(
|
#[must_use]
|
||||||
&self,
|
pub fn call_bool(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> AnyObject<'ctx> {
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
) -> AnyObject<'ctx> {
|
|
||||||
let n = self.bool(ctx);
|
let n = self.bool(ctx);
|
||||||
|
AnyObject::from_bool(ctx, n)
|
||||||
// NAC3 booleans are i8
|
|
||||||
let llvm_i8 = ctx.ctx.i8_type();
|
|
||||||
let n = ctx.builder.build_int_z_extend(n.value, llvm_i8, "bool").unwrap();
|
|
||||||
|
|
||||||
AnyObject { ty: ctx.primitives.bool, value: n.as_basic_value_enum() }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Call `float()` on this object.
|
/// Call `float()` on this object.
|
||||||
|
#[must_use]
|
||||||
pub fn call_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> AnyObject<'ctx> {
|
pub fn call_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> AnyObject<'ctx> {
|
||||||
let f64_model = FloatModel(Float64);
|
let f64_model = FloatModel(Float64);
|
||||||
let llvm_f64 = ctx.ctx.f64_type();
|
let llvm_f64 = ctx.ctx.f64_type();
|
||||||
|
@ -449,6 +447,7 @@ impl<'ctx> AnyObject<'ctx> {
|
||||||
// Call `round()` on this object.
|
// Call `round()` on this object.
|
||||||
//
|
//
|
||||||
// It is possible to specify which kind of int type to return.
|
// It is possible to specify which kind of int type to return.
|
||||||
|
#[must_use]
|
||||||
pub fn call_round<G: CodeGenerator + ?Sized>(
|
pub fn call_round<G: CodeGenerator + ?Sized>(
|
||||||
&self,
|
&self,
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
|
@ -497,6 +496,7 @@ impl<'ctx> AnyObject<'ctx> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Call `min()` or `max()` on two objects.
|
/// Call `min()` or `max()` on two objects.
|
||||||
|
#[must_use]
|
||||||
pub fn call_min_or_max(
|
pub fn call_min_or_max(
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
kind: MinOrMax,
|
kind: MinOrMax,
|
||||||
|
|
|
@ -34,7 +34,7 @@ use indexing::RustNDIndex;
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
context::Context,
|
context::Context,
|
||||||
types::BasicType,
|
types::BasicType,
|
||||||
values::{BasicValue, BasicValueEnum, PointerValue},
|
values::{BasicValue, PointerValue},
|
||||||
AddressSpace, IntPredicate,
|
AddressSpace, IntPredicate,
|
||||||
};
|
};
|
||||||
use nditer::NDIterHandle;
|
use nditer::NDIterHandle;
|
||||||
|
@ -570,7 +570,10 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
|generator, ctx| Ok(NDIterHandle::new(generator, ctx, *self)),
|
|generator, ctx| Ok(NDIterHandle::new(generator, ctx, *self)),
|
||||||
|generator, ctx, nditer| Ok(nditer.has_next(generator, ctx).value),
|
|generator, ctx, nditer| Ok(nditer.has_next(generator, ctx).value),
|
||||||
|generator, ctx, hooks, nditer| body(generator, ctx, hooks, nditer),
|
|generator, ctx, hooks, nditer| body(generator, ctx, hooks, nditer),
|
||||||
|generator, ctx, nditer| Ok(nditer.next(generator, ctx)),
|
|generator, ctx, nditer| {
|
||||||
|
nditer.next(generator, ctx);
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -30,6 +30,7 @@ impl<'ctx> NDIterHandle<'ctx> {
|
||||||
NDIterHandle { ndarray, instance: nditer, indices }
|
NDIterHandle { ndarray, instance: nditer, indices }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
pub fn has_next<G: CodeGenerator + ?Sized>(
|
pub fn has_next<G: CodeGenerator + ?Sized>(
|
||||||
&self,
|
&self,
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
|
@ -43,9 +44,10 @@ impl<'ctx> NDIterHandle<'ctx> {
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
) {
|
) {
|
||||||
call_nac3_nditer_next(generator, ctx, self.instance)
|
call_nac3_nditer_next(generator, ctx, self.instance);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
pub fn get_pointer<G: CodeGenerator + ?Sized>(
|
pub fn get_pointer<G: CodeGenerator + ?Sized>(
|
||||||
&self,
|
&self,
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
|
@ -59,6 +61,7 @@ impl<'ctx> NDIterHandle<'ctx> {
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
pub fn get_scalar<G: CodeGenerator + ?Sized>(
|
pub fn get_scalar<G: CodeGenerator + ?Sized>(
|
||||||
&self,
|
&self,
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
|
@ -69,6 +72,7 @@ impl<'ctx> NDIterHandle<'ctx> {
|
||||||
AnyObject { ty: self.ndarray.dtype, value }
|
AnyObject { ty: self.ndarray.dtype, value }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
pub fn get_index<G: CodeGenerator + ?Sized>(
|
pub fn get_index<G: CodeGenerator + ?Sized>(
|
||||||
&self,
|
&self,
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
|
@ -77,6 +81,7 @@ impl<'ctx> NDIterHandle<'ctx> {
|
||||||
self.instance.get(generator, ctx, |f| f.nth, "index")
|
self.instance.get(generator, ctx, |f| f.nth, "index")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
pub fn get_indices(&self) -> Ptr<'ctx, IntModel<SizeT>> {
|
pub fn get_indices(&self) -> Ptr<'ctx, IntModel<SizeT>> {
|
||||||
self.indices
|
self.indices
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
use std::{any::Any, iter::once};
|
use std::iter::once;
|
||||||
|
|
||||||
use helper::{
|
use helper::{
|
||||||
create_ndims, debug_assert_prim_is_allowed, extract_ndims, make_exception_fields,
|
create_ndims, debug_assert_prim_is_allowed, extract_ndims, make_exception_fields,
|
||||||
|
@ -26,14 +26,13 @@ use crate::{
|
||||||
numpy_new::{self},
|
numpy_new::{self},
|
||||||
object::{
|
object::{
|
||||||
ndarray::{
|
ndarray::{
|
||||||
functions::{FloorOrCeil, MinOrMax},
|
|
||||||
nalgebra::perform_nalgebra_call,
|
nalgebra::perform_nalgebra_call,
|
||||||
scalar::{split_scalar_or_ndarray, ScalarObject, ScalarOrNDArray},
|
scalar::{split_scalar_or_ndarray, ScalarOrNDArray},
|
||||||
shape_util::parse_numpy_int_sequence,
|
shape_util::parse_numpy_int_sequence,
|
||||||
NDArrayObject,
|
NDArrayObject,
|
||||||
},
|
},
|
||||||
tuple::TupleObject,
|
tuple::TupleObject,
|
||||||
AnyObject,
|
AnyObject, FloorOrCeil, MinOrMax,
|
||||||
},
|
},
|
||||||
stmt::exn_constructor,
|
stmt::exn_constructor,
|
||||||
},
|
},
|
||||||
|
@ -1077,7 +1076,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
let this = AnyObject { value: this_arg, ty: this_ty };
|
let this = AnyObject { value: this_arg, ty: this_ty };
|
||||||
let this = NDArrayObject::from_object(generator, ctx, this);
|
let this = NDArrayObject::from_object(generator, ctx, this);
|
||||||
|
|
||||||
let value = ScalarObject { value: value_arg, dtype: value_ty };
|
let value = AnyObject { value: value_arg, ty: value_ty };
|
||||||
this.fill(generator, ctx, value);
|
this.fill(generator, ctx, value);
|
||||||
|
|
||||||
Ok(None)
|
Ok(None)
|
||||||
|
@ -1142,7 +1141,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
ctx,
|
ctx,
|
||||||
ret_dtype,
|
ret_dtype,
|
||||||
|generator, ctx, scalar| {
|
|generator, ctx, scalar| {
|
||||||
let result = match prim {
|
Ok(match prim {
|
||||||
PrimDef::FunInt32 => scalar.call_int32(generator, ctx),
|
PrimDef::FunInt32 => scalar.call_int32(generator, ctx),
|
||||||
PrimDef::FunInt64 => scalar.call_int64(generator, ctx),
|
PrimDef::FunInt64 => scalar.call_int64(generator, ctx),
|
||||||
PrimDef::FunUInt32 => scalar.call_uint32(generator, ctx),
|
PrimDef::FunUInt32 => scalar.call_uint32(generator, ctx),
|
||||||
|
@ -1150,8 +1149,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
PrimDef::FunFloat => scalar.call_float(ctx),
|
PrimDef::FunFloat => scalar.call_float(ctx),
|
||||||
PrimDef::FunBool => scalar.call_bool(ctx),
|
PrimDef::FunBool => scalar.call_bool(ctx),
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
};
|
})
|
||||||
Ok(result.value)
|
|
||||||
},
|
},
|
||||||
)?;
|
)?;
|
||||||
Ok(Some(result.to_basic_value_enum()))
|
Ok(Some(result.to_basic_value_enum()))
|
||||||
|
@ -1206,13 +1204,13 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||||
let arg = AnyObject { ty: arg_ty, value: arg };
|
let arg = AnyObject { ty: arg_ty, value: arg };
|
||||||
|
|
||||||
let ret_int_dtype = size_variant.of_int(&ctx.primitives);
|
let ret_int_ty = size_variant.of_int(&ctx.primitives);
|
||||||
|
|
||||||
let result = split_scalar_or_ndarray(generator, ctx, arg).map(
|
let result = split_scalar_or_ndarray(generator, ctx, arg).map(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ret_int_dtype,
|
ret_int_ty,
|
||||||
|generator, ctx, scalar| Ok(scalar.round(generator, ctx, ret_int_dtype).value),
|
|generator, ctx, scalar| Ok(scalar.call_round(generator, ctx, ret_int_ty)),
|
||||||
)?;
|
)?;
|
||||||
Ok(Some(result.to_basic_value_enum()))
|
Ok(Some(result.to_basic_value_enum()))
|
||||||
}),
|
}),
|
||||||
|
@ -1277,7 +1275,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
ctx,
|
ctx,
|
||||||
int_sized,
|
int_sized,
|
||||||
|generator, ctx, scalar| {
|
|generator, ctx, scalar| {
|
||||||
Ok(scalar.call_floor_or_ceil(generator, ctx, kind, int_sized).value)
|
Ok(scalar.call_floor_or_ceil(generator, ctx, kind, int_sized))
|
||||||
},
|
},
|
||||||
)?;
|
)?;
|
||||||
Ok(Some(result.to_basic_value_enum()))
|
Ok(Some(result.to_basic_value_enum()))
|
||||||
|
@ -1477,7 +1475,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
let fill_value_ty = fun.0.args[1].ty;
|
let fill_value_ty = fun.0.args[1].ty;
|
||||||
let fill_value =
|
let fill_value =
|
||||||
args[1].1.clone().to_basic_value_enum(ctx, generator, fill_value_ty)?;
|
args[1].1.clone().to_basic_value_enum(ctx, generator, fill_value_ty)?;
|
||||||
let fill_value = ScalarObject { dtype: fill_value_ty, value: fill_value };
|
let fill_value = AnyObject { ty: fill_value_ty, value: fill_value };
|
||||||
|
|
||||||
// Implementation
|
// Implementation
|
||||||
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
|
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
|
||||||
|
@ -1858,10 +1856,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.float,
|
ctx.primitives.float,
|
||||||
move |_generator, ctx, scalar| {
|
|generator, ctx, scalar| Ok(scalar.call_np_floor_or_ceil(generator, ctx, kind)),
|
||||||
let result = scalar.np_floor_or_ceil(ctx, kind);
|
|
||||||
Ok(result.value)
|
|
||||||
},
|
|
||||||
)?;
|
)?;
|
||||||
Ok(Some(result.to_basic_value_enum()))
|
Ok(Some(result.to_basic_value_enum()))
|
||||||
}),
|
}),
|
||||||
|
@ -1887,10 +1882,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.float,
|
ctx.primitives.float,
|
||||||
|_generator, ctx, scalar| {
|
|generator, ctx, scalar| Ok(scalar.call_np_round(generator, ctx)),
|
||||||
let result = scalar.np_round(ctx);
|
|
||||||
Ok(result.value)
|
|
||||||
},
|
|
||||||
)?;
|
)?;
|
||||||
Ok(Some(result.to_basic_value_enum()))
|
Ok(Some(result.to_basic_value_enum()))
|
||||||
}),
|
}),
|
||||||
|
@ -1977,9 +1969,9 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let m = ScalarObject { dtype: m_ty, value: m_val };
|
let m = AnyObject { ty: m_ty, value: m_val };
|
||||||
let n = ScalarObject { dtype: n_ty, value: n_val };
|
let n = AnyObject { ty: n_ty, value: n_val };
|
||||||
let result = ScalarObject::min_or_max(ctx, kind, m, n);
|
let result = AnyObject::call_min_or_max(ctx, kind, m, n);
|
||||||
Ok(Some(result.value))
|
Ok(Some(result.value))
|
||||||
},
|
},
|
||||||
)))),
|
)))),
|
||||||
|
@ -2104,9 +2096,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
|_generator, ctx, scalars| {
|
|_generator, ctx, scalars| {
|
||||||
let x1 = scalars[0];
|
let x1 = scalars[0];
|
||||||
let x2 = scalars[1];
|
let x2 = scalars[1];
|
||||||
|
Ok(AnyObject::call_min_or_max(ctx, kind, x1, x2))
|
||||||
let result = ScalarObject::min_or_max(ctx, kind, x1, x2);
|
|
||||||
Ok(result.value)
|
|
||||||
},
|
},
|
||||||
)?;
|
)?;
|
||||||
Ok(Some(result.to_basic_value_enum()))
|
Ok(Some(result.to_basic_value_enum()))
|
||||||
|
@ -2148,7 +2138,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
num_ty.ty,
|
num_ty.ty,
|
||||||
|_generator, ctx, scalar| Ok(scalar.abs(ctx).value),
|
|generator, ctx, scalar| Ok(scalar.call_abs(generator, ctx)),
|
||||||
)?;
|
)?;
|
||||||
Ok(Some(result.to_basic_value_enum()))
|
Ok(Some(result.to_basic_value_enum()))
|
||||||
},
|
},
|
||||||
|
@ -2185,9 +2175,10 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.bool,
|
ctx.primitives.bool,
|
||||||
|generator, ctx, scalar| {
|
|generator, ctx, scalar| {
|
||||||
let n = scalar.into_float64(ctx);
|
let n = scalar.into_float(ctx).value;
|
||||||
let n = function(generator, ctx, n);
|
let n = function(generator, ctx, n);
|
||||||
Ok(n.as_basic_value_enum())
|
let n = IntModel(Bool).believe_value(n);
|
||||||
|
Ok(AnyObject::from_bool(ctx, n))
|
||||||
},
|
},
|
||||||
)?;
|
)?;
|
||||||
Ok(Some(result.to_basic_value_enum()))
|
Ok(Some(result.to_basic_value_enum()))
|
||||||
|
@ -2254,8 +2245,9 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.float,
|
ctx.primitives.float,
|
||||||
|_generator, ctx, scalar| {
|
|_generator, ctx, scalar| {
|
||||||
let n = scalar.into_float64(ctx);
|
let n = scalar.into_float(ctx).value;
|
||||||
let n = match prim {
|
|
||||||
|
let result = match prim {
|
||||||
PrimDef::FunNpSin => llvm_intrinsics::call_float_sin(ctx, n, None),
|
PrimDef::FunNpSin => llvm_intrinsics::call_float_sin(ctx, n, None),
|
||||||
PrimDef::FunNpCos => llvm_intrinsics::call_float_cos(ctx, n, None),
|
PrimDef::FunNpCos => llvm_intrinsics::call_float_cos(ctx, n, None),
|
||||||
PrimDef::FunNpTan => extern_fns::call_tan(ctx, n, None),
|
PrimDef::FunNpTan => extern_fns::call_tan(ctx, n, None),
|
||||||
|
@ -2297,7 +2289,11 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
|
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
};
|
};
|
||||||
Ok(n.as_basic_value_enum())
|
|
||||||
|
Ok(AnyObject {
|
||||||
|
ty: ctx.primitives.float,
|
||||||
|
value: result.as_basic_value_enum(),
|
||||||
|
})
|
||||||
},
|
},
|
||||||
)?;
|
)?;
|
||||||
Ok(Some(result.to_basic_value_enum()))
|
Ok(Some(result.to_basic_value_enum()))
|
||||||
|
@ -2385,48 +2381,48 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
// TODO: This looks ugly
|
// TODO: This looks ugly
|
||||||
let result = match prim {
|
let result = match prim {
|
||||||
PrimDef::FunNpArctan2 => {
|
PrimDef::FunNpArctan2 => {
|
||||||
let x1 = x1.into_float64(ctx);
|
let x1 = x1.into_float(ctx).value;
|
||||||
let x2 = x2.into_float64(ctx);
|
let x2 = x2.into_float(ctx).value;
|
||||||
extern_fns::call_atan2(ctx, x1, x2, None).as_basic_value_enum()
|
extern_fns::call_atan2(ctx, x1, x2, None).as_basic_value_enum()
|
||||||
}
|
}
|
||||||
PrimDef::FunNpCopysign => {
|
PrimDef::FunNpCopysign => {
|
||||||
let x1 = x1.into_float64(ctx);
|
let x1 = x1.into_float(ctx).value;
|
||||||
let x2 = x2.into_float64(ctx);
|
let x2 = x2.into_float(ctx).value;
|
||||||
llvm_intrinsics::call_float_copysign(ctx, x1, x2, None)
|
llvm_intrinsics::call_float_copysign(ctx, x1, x2, None)
|
||||||
.as_basic_value_enum()
|
.as_basic_value_enum()
|
||||||
}
|
}
|
||||||
PrimDef::FunNpFmax => {
|
PrimDef::FunNpFmax => {
|
||||||
let x1 = x1.into_float64(ctx);
|
let x1 = x1.into_float(ctx).value;
|
||||||
let x2 = x2.into_float64(ctx);
|
let x2 = x2.into_float(ctx).value;
|
||||||
llvm_intrinsics::call_float_maxnum(ctx, x1, x2, None)
|
llvm_intrinsics::call_float_maxnum(ctx, x1, x2, None)
|
||||||
.as_basic_value_enum()
|
.as_basic_value_enum()
|
||||||
}
|
}
|
||||||
PrimDef::FunNpFmin => {
|
PrimDef::FunNpFmin => {
|
||||||
let x1 = x1.into_float64(ctx);
|
let x1 = x1.into_float(ctx).value;
|
||||||
let x2 = x2.into_float64(ctx);
|
let x2 = x2.into_float(ctx).value;
|
||||||
llvm_intrinsics::call_float_minnum(ctx, x1, x2, None)
|
llvm_intrinsics::call_float_minnum(ctx, x1, x2, None)
|
||||||
.as_basic_value_enum()
|
.as_basic_value_enum()
|
||||||
}
|
}
|
||||||
PrimDef::FunNpHypot => {
|
PrimDef::FunNpHypot => {
|
||||||
let x1 = x1.into_float64(ctx);
|
let x1 = x1.into_float(ctx).value;
|
||||||
let x2 = x2.into_float64(ctx);
|
let x2 = x2.into_float(ctx).value;
|
||||||
extern_fns::call_hypot(ctx, x1, x2, None).as_basic_value_enum()
|
extern_fns::call_hypot(ctx, x1, x2, None).as_basic_value_enum()
|
||||||
}
|
}
|
||||||
PrimDef::FunNpNextAfter => {
|
PrimDef::FunNpNextAfter => {
|
||||||
let x1 = x1.into_float64(ctx);
|
let x1 = x1.into_float(ctx).value;
|
||||||
let x2 = x2.into_float64(ctx);
|
let x2 = x2.into_float(ctx).value;
|
||||||
extern_fns::call_nextafter(ctx, x1, x2, None)
|
extern_fns::call_nextafter(ctx, x1, x2, None)
|
||||||
.as_basic_value_enum()
|
.as_basic_value_enum()
|
||||||
}
|
}
|
||||||
PrimDef::FunNpLdExp => {
|
PrimDef::FunNpLdExp => {
|
||||||
let x1 = x1.into_float64(ctx);
|
let x1 = x1.into_float(ctx).value;
|
||||||
let x2 = x2.into_int32(ctx);
|
let x2 = x2.into_int32(ctx).value;
|
||||||
extern_fns::call_ldexp(ctx, x1, x2, None).as_basic_value_enum()
|
extern_fns::call_ldexp(ctx, x1, x2, None).as_basic_value_enum()
|
||||||
}
|
}
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(result)
|
Ok(AnyObject { ty: ret_dtype, value: result })
|
||||||
},
|
},
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
|
@ -2628,13 +2624,9 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
zero,
|
zero,
|
||||||
ScalarObject { dtype: x2_ty, value: x2_val },
|
AnyObject { ty: x2_ty, value: x2_val },
|
||||||
);
|
);
|
||||||
|
|
||||||
// alloca_constant_shape
|
|
||||||
|
|
||||||
// let x2 = x2.as_ndarray(generator, ctx);
|
|
||||||
|
|
||||||
let [out] = perform_nalgebra_call(
|
let [out] = perform_nalgebra_call(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
|
|
Loading…
Reference in New Issue