forked from M-Labs/nac3
1
0
Fork 0

WIP: core/ndstrides: checkpoint 13

This commit is contained in:
lyken 2024-08-15 14:38:05 +08:00
parent 4b765cfb27
commit a69a441bdd
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
6 changed files with 101 additions and 98 deletions

View File

@ -1555,9 +1555,8 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
if op.base == Operator::MatMult { if op.base == Operator::MatMult {
// Handle `left @ right` // Handle `left @ right`
let result = NDArrayObject::matmul(generator, ctx, left, right, out) let result = NDArrayObject::matmul(generator, ctx, left, right, out)
.split_unsized(generator, ctx) .split_unsized(generator, ctx);
.to_basic_value_enum(); Ok(Some(ValueEnum::Dynamic(result.to_basic_value_enum())))
Ok(Some(ValueEnum::Dynamic(result)))
} else { } else {
// For other operators like +, -, etc...; do them element-wise-ly // For other operators like +, -, etc...; do them element-wise-ly
let result = NDArrayObject::broadcasting_starmap( let result = NDArrayObject::broadcasting_starmap(
@ -1566,18 +1565,21 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
&[left, right], &[left, right],
out, out,
|generator, ctx, scalars| { |generator, ctx, scalars| {
let left = scalars[0]; let left_scalar = scalars[0];
let right = scalars[1]; let right_scalar = scalars[1];
gen_binop_expr_with_values(
let result = gen_binop_expr_with_values(
generator, generator,
ctx, ctx,
(&Some(left.dtype), left.value), (&Some(left_scalar.ty), left_scalar.value),
op, op,
(&Some(right.dtype), right.value), (&Some(right_scalar.ty), right_scalar.value),
ctx.current_loc, ctx.current_loc,
)? )?
.unwrap() .unwrap()
.to_basic_value_enum(ctx, generator, common_dtype) .to_basic_value_enum(ctx, generator, common_dtype)?;
Ok(AnyObject { value: result, ty: common_dtype })
}, },
)?; )?;
Ok(Some(ValueEnum::Dynamic(result.instance.value.as_basic_value_enum()))) Ok(Some(ValueEnum::Dynamic(result.instance.value.as_basic_value_enum())))
@ -1671,6 +1673,7 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>(
/// Generates LLVM IR for a unary operator expression using the [`Type`] and /// Generates LLVM IR for a unary operator expression using the [`Type`] and
/// [LLVM value][`BasicValueEnum`] of the operands. /// [LLVM value][`BasicValueEnum`] of the operands.
#[allow(clippy::only_used_in_recursion)]
pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,

View File

@ -63,7 +63,7 @@ impl<'ctx, Element: Model<'ctx>> Ptr<'ctx, ArrayModel<Element>> {
assert!(i < self.model.0.len); assert!(i < self.model.0.len);
let zero = ctx.ctx.i32_type().const_zero(); let zero = ctx.ctx.i32_type().const_zero();
let i = ctx.ctx.i32_type().const_int(i as u64, false); let i = ctx.ctx.i32_type().const_int(u64::from(i), false);
let ptr = unsafe { ctx.builder.build_in_bounds_gep(self.value, &[zero, i], name).unwrap() }; let ptr = unsafe { ctx.builder.build_in_bounds_gep(self.value, &[zero, i], name).unwrap() };
PtrModel(self.model.0.element).check_value(generator, ctx.ctx, ptr).unwrap() PtrModel(self.model.0.element).check_value(generator, ctx.ctx, ptr).unwrap()
@ -114,7 +114,7 @@ impl<'ctx, const LEN: u32, Element: Model<'ctx>> Ptr<'ctx, NArrayModel<LEN, Elem
assert!(i < LEN); assert!(i < LEN);
let zero = ctx.ctx.i32_type().const_zero(); let zero = ctx.ctx.i32_type().const_zero();
let i = ctx.ctx.i32_type().const_int(i as u64, false); let i = ctx.ctx.i32_type().const_int(u64::from(i), false);
let ptr = unsafe { ctx.builder.build_in_bounds_gep(self.value, &[zero, i], name).unwrap() }; let ptr = unsafe { ctx.builder.build_in_bounds_gep(self.value, &[zero, i], name).unwrap() };
PtrModel(self.model.0 .0).check_value(generator, ctx.ctx, ptr).unwrap() PtrModel(self.model.0 .0).check_value(generator, ctx.ctx, ptr).unwrap()

View File

@ -86,10 +86,17 @@ impl<'ctx> AnyObject<'ctx> {
matches!(&*ctx.unifier.get_ty(self.ty), TypeEnum::TTuple { .. }) matches!(&*ctx.unifier.get_ty(self.ty), TypeEnum::TTuple { .. })
} }
pub fn into_tuple() {}
pub fn is_int32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool { pub fn is_int32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
ctx.unifier.unioned(self.ty, ctx.primitives.int32) ctx.unifier.unioned(self.ty, ctx.primitives.int32)
} }
pub fn into_int32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Int<'ctx, Int32> {
assert!(self.is_int32(ctx));
IntModel(Int32).believe_value(self.value.into_int_value())
}
pub fn is_uint32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool { pub fn is_uint32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
ctx.unifier.unioned(self.ty, ctx.primitives.uint32) ctx.unifier.unioned(self.ty, ctx.primitives.uint32)
} }
@ -106,44 +113,33 @@ impl<'ctx> AnyObject<'ctx> {
ctx.unifier.unioned(self.ty, ctx.primitives.bool) ctx.unifier.unioned(self.ty, ctx.primitives.bool)
} }
/// Returns true if the object type is `bool`, `int32`, `int64`, `uint32`, or `uint64`.
pub fn is_int_like(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool { pub fn is_int_like(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
ctx.unifier.unioned_any(self.ty, int_like(ctx)) ctx.unifier.unioned_any(self.ty, int_like(ctx))
} }
/// Returns true if the object type is `int32`, `int64`.
pub fn is_signed_int(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool { pub fn is_signed_int(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
ctx.unifier.unioned_any(self.ty, signed_ints(ctx)) ctx.unifier.unioned_any(self.ty, signed_ints(ctx))
} }
/// Returns true if the object type is `uint32`, `uint64`.
pub fn is_unsigned_int(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool { pub fn is_unsigned_int(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
ctx.unifier.unioned_any(self.ty, unsigned_ints(ctx)) ctx.unifier.unioned_any(self.ty, unsigned_ints(ctx))
} }
/// Convenience function. If object has type `int32`, `int64`, `uint32`, `uint64`, or `bool`,
/// get its underlying LLVM value.
///
/// Panic if the type is wrong.
pub fn into_int(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> IntValue<'ctx> { pub fn into_int(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
if self.is_int_like(ctx) { assert!(self.is_int_like(ctx));
self.value.into_int_value() self.value.into_int_value()
} else {
panic!("not an int32 type")
}
} }
pub fn is_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool { pub fn is_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
self.is_obj(ctx, PrimDef::Float) self.is_obj(ctx, PrimDef::Float)
} }
/// Convenience function. If object has type `float`, get its underlying LLVM value.
///
/// Panic if the type is wrong.
pub fn into_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Float<'ctx, Float64> { pub fn into_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Float<'ctx, Float64> {
if self.is_float(ctx) { assert!(self.is_float(ctx));
// self.value must be a FloatValue
FloatModel(Float64).believe_value(self.value.into_float_value()) FloatModel(Float64).believe_value(self.value.into_float_value())
} else {
panic!("not a float type")
}
} }
pub fn is_ndarray(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool { pub fn is_ndarray(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> bool {
@ -158,6 +154,15 @@ impl<'ctx> AnyObject<'ctx> {
NDArrayObject::from_object(generator, ctx, *self) NDArrayObject::from_object(generator, ctx, *self)
} }
/// Create an object from a boolean from an i1.
///
/// NOTE: In NAC3, booleans are i8. This function does converts the input i1 to an i8.
pub fn from_bool(ctx: &mut CodeGenContext<'ctx, '_>, n: Int<'ctx, Bool>) -> AnyObject<'ctx> {
let llvm_i8 = ctx.ctx.i8_type();
let value = ctx.builder.build_int_z_extend(n.value, llvm_i8, "bool").unwrap();
AnyObject { value: value.as_basic_value_enum(), ty: ctx.primitives.bool }
}
/// Helper function to compare two scalars. /// Helper function to compare two scalars.
/// ///
/// Only int-to-int and float-to-float comparisons are allowed. /// Only int-to-int and float-to-float comparisons are allowed.
@ -172,9 +177,7 @@ impl<'ctx> AnyObject<'ctx> {
float_predicate: FloatPredicate, float_predicate: FloatPredicate,
name: &str, name: &str,
) -> Int<'ctx, Bool> { ) -> Int<'ctx, Bool> {
if !ctx.unifier.unioned(lhs.ty, rhs.ty) { assert!(ctx.unifier.unioned(lhs.ty, rhs.ty), "lhs and rhs type should be the same");
panic!("lhs and rhs type are not the same.")
}
let bool_model = IntModel(Bool); let bool_model = IntModel(Bool);
@ -320,6 +323,7 @@ impl<'ctx> AnyObject<'ctx> {
} }
// Get the `len()` of this object. // Get the `len()` of this object.
#[must_use]
pub fn call_len<G: CodeGenerator + ?Sized>( pub fn call_len<G: CodeGenerator + ?Sized>(
&self, &self,
generator: &mut G, generator: &mut G,
@ -378,20 +382,14 @@ impl<'ctx> AnyObject<'ctx> {
} }
/// Call `bool()` on this object. /// Call `bool()` on this object.
pub fn call_bool<G: CodeGenerator + ?Sized>( #[must_use]
&self, pub fn call_bool(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> AnyObject<'ctx> {
ctx: &mut CodeGenContext<'ctx, '_>,
) -> AnyObject<'ctx> {
let n = self.bool(ctx); let n = self.bool(ctx);
AnyObject::from_bool(ctx, n)
// NAC3 booleans are i8
let llvm_i8 = ctx.ctx.i8_type();
let n = ctx.builder.build_int_z_extend(n.value, llvm_i8, "bool").unwrap();
AnyObject { ty: ctx.primitives.bool, value: n.as_basic_value_enum() }
} }
/// Call `float()` on this object. /// Call `float()` on this object.
#[must_use]
pub fn call_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> AnyObject<'ctx> { pub fn call_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> AnyObject<'ctx> {
let f64_model = FloatModel(Float64); let f64_model = FloatModel(Float64);
let llvm_f64 = ctx.ctx.f64_type(); let llvm_f64 = ctx.ctx.f64_type();
@ -449,6 +447,7 @@ impl<'ctx> AnyObject<'ctx> {
// Call `round()` on this object. // Call `round()` on this object.
// //
// It is possible to specify which kind of int type to return. // It is possible to specify which kind of int type to return.
#[must_use]
pub fn call_round<G: CodeGenerator + ?Sized>( pub fn call_round<G: CodeGenerator + ?Sized>(
&self, &self,
generator: &mut G, generator: &mut G,
@ -497,6 +496,7 @@ impl<'ctx> AnyObject<'ctx> {
} }
/// Call `min()` or `max()` on two objects. /// Call `min()` or `max()` on two objects.
#[must_use]
pub fn call_min_or_max( pub fn call_min_or_max(
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
kind: MinOrMax, kind: MinOrMax,

View File

@ -34,7 +34,7 @@ use indexing::RustNDIndex;
use inkwell::{ use inkwell::{
context::Context, context::Context,
types::BasicType, types::BasicType,
values::{BasicValue, BasicValueEnum, PointerValue}, values::{BasicValue, PointerValue},
AddressSpace, IntPredicate, AddressSpace, IntPredicate,
}; };
use nditer::NDIterHandle; use nditer::NDIterHandle;
@ -570,7 +570,10 @@ impl<'ctx> NDArrayObject<'ctx> {
|generator, ctx| Ok(NDIterHandle::new(generator, ctx, *self)), |generator, ctx| Ok(NDIterHandle::new(generator, ctx, *self)),
|generator, ctx, nditer| Ok(nditer.has_next(generator, ctx).value), |generator, ctx, nditer| Ok(nditer.has_next(generator, ctx).value),
|generator, ctx, hooks, nditer| body(generator, ctx, hooks, nditer), |generator, ctx, hooks, nditer| body(generator, ctx, hooks, nditer),
|generator, ctx, nditer| Ok(nditer.next(generator, ctx)), |generator, ctx, nditer| {
nditer.next(generator, ctx);
Ok(())
},
) )
} }

View File

@ -30,6 +30,7 @@ impl<'ctx> NDIterHandle<'ctx> {
NDIterHandle { ndarray, instance: nditer, indices } NDIterHandle { ndarray, instance: nditer, indices }
} }
#[must_use]
pub fn has_next<G: CodeGenerator + ?Sized>( pub fn has_next<G: CodeGenerator + ?Sized>(
&self, &self,
generator: &mut G, generator: &mut G,
@ -43,9 +44,10 @@ impl<'ctx> NDIterHandle<'ctx> {
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
) { ) {
call_nac3_nditer_next(generator, ctx, self.instance) call_nac3_nditer_next(generator, ctx, self.instance);
} }
#[must_use]
pub fn get_pointer<G: CodeGenerator + ?Sized>( pub fn get_pointer<G: CodeGenerator + ?Sized>(
&self, &self,
generator: &mut G, generator: &mut G,
@ -59,6 +61,7 @@ impl<'ctx> NDIterHandle<'ctx> {
.unwrap() .unwrap()
} }
#[must_use]
pub fn get_scalar<G: CodeGenerator + ?Sized>( pub fn get_scalar<G: CodeGenerator + ?Sized>(
&self, &self,
generator: &mut G, generator: &mut G,
@ -69,6 +72,7 @@ impl<'ctx> NDIterHandle<'ctx> {
AnyObject { ty: self.ndarray.dtype, value } AnyObject { ty: self.ndarray.dtype, value }
} }
#[must_use]
pub fn get_index<G: CodeGenerator + ?Sized>( pub fn get_index<G: CodeGenerator + ?Sized>(
&self, &self,
generator: &mut G, generator: &mut G,
@ -77,6 +81,7 @@ impl<'ctx> NDIterHandle<'ctx> {
self.instance.get(generator, ctx, |f| f.nth, "index") self.instance.get(generator, ctx, |f| f.nth, "index")
} }
#[must_use]
pub fn get_indices(&self) -> Ptr<'ctx, IntModel<SizeT>> { pub fn get_indices(&self) -> Ptr<'ctx, IntModel<SizeT>> {
self.indices self.indices
} }

View File

@ -1,4 +1,4 @@
use std::{any::Any, iter::once}; use std::iter::once;
use helper::{ use helper::{
create_ndims, debug_assert_prim_is_allowed, extract_ndims, make_exception_fields, create_ndims, debug_assert_prim_is_allowed, extract_ndims, make_exception_fields,
@ -26,14 +26,13 @@ use crate::{
numpy_new::{self}, numpy_new::{self},
object::{ object::{
ndarray::{ ndarray::{
functions::{FloorOrCeil, MinOrMax},
nalgebra::perform_nalgebra_call, nalgebra::perform_nalgebra_call,
scalar::{split_scalar_or_ndarray, ScalarObject, ScalarOrNDArray}, scalar::{split_scalar_or_ndarray, ScalarOrNDArray},
shape_util::parse_numpy_int_sequence, shape_util::parse_numpy_int_sequence,
NDArrayObject, NDArrayObject,
}, },
tuple::TupleObject, tuple::TupleObject,
AnyObject, AnyObject, FloorOrCeil, MinOrMax,
}, },
stmt::exn_constructor, stmt::exn_constructor,
}, },
@ -1077,7 +1076,7 @@ impl<'a> BuiltinBuilder<'a> {
let this = AnyObject { value: this_arg, ty: this_ty }; let this = AnyObject { value: this_arg, ty: this_ty };
let this = NDArrayObject::from_object(generator, ctx, this); let this = NDArrayObject::from_object(generator, ctx, this);
let value = ScalarObject { value: value_arg, dtype: value_ty }; let value = AnyObject { value: value_arg, ty: value_ty };
this.fill(generator, ctx, value); this.fill(generator, ctx, value);
Ok(None) Ok(None)
@ -1142,7 +1141,7 @@ impl<'a> BuiltinBuilder<'a> {
ctx, ctx,
ret_dtype, ret_dtype,
|generator, ctx, scalar| { |generator, ctx, scalar| {
let result = match prim { Ok(match prim {
PrimDef::FunInt32 => scalar.call_int32(generator, ctx), PrimDef::FunInt32 => scalar.call_int32(generator, ctx),
PrimDef::FunInt64 => scalar.call_int64(generator, ctx), PrimDef::FunInt64 => scalar.call_int64(generator, ctx),
PrimDef::FunUInt32 => scalar.call_uint32(generator, ctx), PrimDef::FunUInt32 => scalar.call_uint32(generator, ctx),
@ -1150,8 +1149,7 @@ impl<'a> BuiltinBuilder<'a> {
PrimDef::FunFloat => scalar.call_float(ctx), PrimDef::FunFloat => scalar.call_float(ctx),
PrimDef::FunBool => scalar.call_bool(ctx), PrimDef::FunBool => scalar.call_bool(ctx),
_ => unreachable!(), _ => unreachable!(),
}; })
Ok(result.value)
}, },
)?; )?;
Ok(Some(result.to_basic_value_enum())) Ok(Some(result.to_basic_value_enum()))
@ -1206,13 +1204,13 @@ impl<'a> BuiltinBuilder<'a> {
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
let arg = AnyObject { ty: arg_ty, value: arg }; let arg = AnyObject { ty: arg_ty, value: arg };
let ret_int_dtype = size_variant.of_int(&ctx.primitives); let ret_int_ty = size_variant.of_int(&ctx.primitives);
let result = split_scalar_or_ndarray(generator, ctx, arg).map( let result = split_scalar_or_ndarray(generator, ctx, arg).map(
generator, generator,
ctx, ctx,
ret_int_dtype, ret_int_ty,
|generator, ctx, scalar| Ok(scalar.round(generator, ctx, ret_int_dtype).value), |generator, ctx, scalar| Ok(scalar.call_round(generator, ctx, ret_int_ty)),
)?; )?;
Ok(Some(result.to_basic_value_enum())) Ok(Some(result.to_basic_value_enum()))
}), }),
@ -1277,7 +1275,7 @@ impl<'a> BuiltinBuilder<'a> {
ctx, ctx,
int_sized, int_sized,
|generator, ctx, scalar| { |generator, ctx, scalar| {
Ok(scalar.call_floor_or_ceil(generator, ctx, kind, int_sized).value) Ok(scalar.call_floor_or_ceil(generator, ctx, kind, int_sized))
}, },
)?; )?;
Ok(Some(result.to_basic_value_enum())) Ok(Some(result.to_basic_value_enum()))
@ -1477,7 +1475,7 @@ impl<'a> BuiltinBuilder<'a> {
let fill_value_ty = fun.0.args[1].ty; let fill_value_ty = fun.0.args[1].ty;
let fill_value = let fill_value =
args[1].1.clone().to_basic_value_enum(ctx, generator, fill_value_ty)?; args[1].1.clone().to_basic_value_enum(ctx, generator, fill_value_ty)?;
let fill_value = ScalarObject { dtype: fill_value_ty, value: fill_value }; let fill_value = AnyObject { ty: fill_value_ty, value: fill_value };
// Implementation // Implementation
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret); let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
@ -1858,10 +1856,7 @@ impl<'a> BuiltinBuilder<'a> {
generator, generator,
ctx, ctx,
ctx.primitives.float, ctx.primitives.float,
move |_generator, ctx, scalar| { |generator, ctx, scalar| Ok(scalar.call_np_floor_or_ceil(generator, ctx, kind)),
let result = scalar.np_floor_or_ceil(ctx, kind);
Ok(result.value)
},
)?; )?;
Ok(Some(result.to_basic_value_enum())) Ok(Some(result.to_basic_value_enum()))
}), }),
@ -1887,10 +1882,7 @@ impl<'a> BuiltinBuilder<'a> {
generator, generator,
ctx, ctx,
ctx.primitives.float, ctx.primitives.float,
|_generator, ctx, scalar| { |generator, ctx, scalar| Ok(scalar.call_np_round(generator, ctx)),
let result = scalar.np_round(ctx);
Ok(result.value)
},
)?; )?;
Ok(Some(result.to_basic_value_enum())) Ok(Some(result.to_basic_value_enum()))
}), }),
@ -1977,9 +1969,9 @@ impl<'a> BuiltinBuilder<'a> {
_ => unreachable!(), _ => unreachable!(),
}; };
let m = ScalarObject { dtype: m_ty, value: m_val }; let m = AnyObject { ty: m_ty, value: m_val };
let n = ScalarObject { dtype: n_ty, value: n_val }; let n = AnyObject { ty: n_ty, value: n_val };
let result = ScalarObject::min_or_max(ctx, kind, m, n); let result = AnyObject::call_min_or_max(ctx, kind, m, n);
Ok(Some(result.value)) Ok(Some(result.value))
}, },
)))), )))),
@ -2104,9 +2096,7 @@ impl<'a> BuiltinBuilder<'a> {
|_generator, ctx, scalars| { |_generator, ctx, scalars| {
let x1 = scalars[0]; let x1 = scalars[0];
let x2 = scalars[1]; let x2 = scalars[1];
Ok(AnyObject::call_min_or_max(ctx, kind, x1, x2))
let result = ScalarObject::min_or_max(ctx, kind, x1, x2);
Ok(result.value)
}, },
)?; )?;
Ok(Some(result.to_basic_value_enum())) Ok(Some(result.to_basic_value_enum()))
@ -2148,7 +2138,7 @@ impl<'a> BuiltinBuilder<'a> {
generator, generator,
ctx, ctx,
num_ty.ty, num_ty.ty,
|_generator, ctx, scalar| Ok(scalar.abs(ctx).value), |generator, ctx, scalar| Ok(scalar.call_abs(generator, ctx)),
)?; )?;
Ok(Some(result.to_basic_value_enum())) Ok(Some(result.to_basic_value_enum()))
}, },
@ -2185,9 +2175,10 @@ impl<'a> BuiltinBuilder<'a> {
ctx, ctx,
ctx.primitives.bool, ctx.primitives.bool,
|generator, ctx, scalar| { |generator, ctx, scalar| {
let n = scalar.into_float64(ctx); let n = scalar.into_float(ctx).value;
let n = function(generator, ctx, n); let n = function(generator, ctx, n);
Ok(n.as_basic_value_enum()) let n = IntModel(Bool).believe_value(n);
Ok(AnyObject::from_bool(ctx, n))
}, },
)?; )?;
Ok(Some(result.to_basic_value_enum())) Ok(Some(result.to_basic_value_enum()))
@ -2254,8 +2245,9 @@ impl<'a> BuiltinBuilder<'a> {
ctx, ctx,
ctx.primitives.float, ctx.primitives.float,
|_generator, ctx, scalar| { |_generator, ctx, scalar| {
let n = scalar.into_float64(ctx); let n = scalar.into_float(ctx).value;
let n = match prim {
let result = match prim {
PrimDef::FunNpSin => llvm_intrinsics::call_float_sin(ctx, n, None), PrimDef::FunNpSin => llvm_intrinsics::call_float_sin(ctx, n, None),
PrimDef::FunNpCos => llvm_intrinsics::call_float_cos(ctx, n, None), PrimDef::FunNpCos => llvm_intrinsics::call_float_cos(ctx, n, None),
PrimDef::FunNpTan => extern_fns::call_tan(ctx, n, None), PrimDef::FunNpTan => extern_fns::call_tan(ctx, n, None),
@ -2297,7 +2289,11 @@ impl<'a> BuiltinBuilder<'a> {
_ => unreachable!(), _ => unreachable!(),
}; };
Ok(n.as_basic_value_enum())
Ok(AnyObject {
ty: ctx.primitives.float,
value: result.as_basic_value_enum(),
})
}, },
)?; )?;
Ok(Some(result.to_basic_value_enum())) Ok(Some(result.to_basic_value_enum()))
@ -2385,48 +2381,48 @@ impl<'a> BuiltinBuilder<'a> {
// TODO: This looks ugly // TODO: This looks ugly
let result = match prim { let result = match prim {
PrimDef::FunNpArctan2 => { PrimDef::FunNpArctan2 => {
let x1 = x1.into_float64(ctx); let x1 = x1.into_float(ctx).value;
let x2 = x2.into_float64(ctx); let x2 = x2.into_float(ctx).value;
extern_fns::call_atan2(ctx, x1, x2, None).as_basic_value_enum() extern_fns::call_atan2(ctx, x1, x2, None).as_basic_value_enum()
} }
PrimDef::FunNpCopysign => { PrimDef::FunNpCopysign => {
let x1 = x1.into_float64(ctx); let x1 = x1.into_float(ctx).value;
let x2 = x2.into_float64(ctx); let x2 = x2.into_float(ctx).value;
llvm_intrinsics::call_float_copysign(ctx, x1, x2, None) llvm_intrinsics::call_float_copysign(ctx, x1, x2, None)
.as_basic_value_enum() .as_basic_value_enum()
} }
PrimDef::FunNpFmax => { PrimDef::FunNpFmax => {
let x1 = x1.into_float64(ctx); let x1 = x1.into_float(ctx).value;
let x2 = x2.into_float64(ctx); let x2 = x2.into_float(ctx).value;
llvm_intrinsics::call_float_maxnum(ctx, x1, x2, None) llvm_intrinsics::call_float_maxnum(ctx, x1, x2, None)
.as_basic_value_enum() .as_basic_value_enum()
} }
PrimDef::FunNpFmin => { PrimDef::FunNpFmin => {
let x1 = x1.into_float64(ctx); let x1 = x1.into_float(ctx).value;
let x2 = x2.into_float64(ctx); let x2 = x2.into_float(ctx).value;
llvm_intrinsics::call_float_minnum(ctx, x1, x2, None) llvm_intrinsics::call_float_minnum(ctx, x1, x2, None)
.as_basic_value_enum() .as_basic_value_enum()
} }
PrimDef::FunNpHypot => { PrimDef::FunNpHypot => {
let x1 = x1.into_float64(ctx); let x1 = x1.into_float(ctx).value;
let x2 = x2.into_float64(ctx); let x2 = x2.into_float(ctx).value;
extern_fns::call_hypot(ctx, x1, x2, None).as_basic_value_enum() extern_fns::call_hypot(ctx, x1, x2, None).as_basic_value_enum()
} }
PrimDef::FunNpNextAfter => { PrimDef::FunNpNextAfter => {
let x1 = x1.into_float64(ctx); let x1 = x1.into_float(ctx).value;
let x2 = x2.into_float64(ctx); let x2 = x2.into_float(ctx).value;
extern_fns::call_nextafter(ctx, x1, x2, None) extern_fns::call_nextafter(ctx, x1, x2, None)
.as_basic_value_enum() .as_basic_value_enum()
} }
PrimDef::FunNpLdExp => { PrimDef::FunNpLdExp => {
let x1 = x1.into_float64(ctx); let x1 = x1.into_float(ctx).value;
let x2 = x2.into_int32(ctx); let x2 = x2.into_int32(ctx).value;
extern_fns::call_ldexp(ctx, x1, x2, None).as_basic_value_enum() extern_fns::call_ldexp(ctx, x1, x2, None).as_basic_value_enum()
} }
_ => unreachable!(), _ => unreachable!(),
}; };
Ok(result) Ok(AnyObject { ty: ret_dtype, value: result })
}, },
)?; )?;
@ -2628,13 +2624,9 @@ impl<'a> BuiltinBuilder<'a> {
generator, generator,
ctx, ctx,
zero, zero,
ScalarObject { dtype: x2_ty, value: x2_val }, AnyObject { ty: x2_ty, value: x2_val },
); );
// alloca_constant_shape
// let x2 = x2.as_ndarray(generator, ctx);
let [out] = perform_nalgebra_call( let [out] = perform_nalgebra_call(
generator, generator,
ctx, ctx,