forked from M-Labs/nac3
core/ndstrides: checkpoint 3
This commit is contained in:
parent
937b36dcfd
commit
858b4b9f3f
File diff suppressed because it is too large
Load Diff
|
@ -15,11 +15,11 @@ use crate::{
|
|||
},
|
||||
},
|
||||
symbol_resolver::ValueEnum,
|
||||
toplevel::{numpy::unpack_ndarray_var_tys, DefinitionId},
|
||||
typecheck::{
|
||||
numpy::extract_ndims,
|
||||
typedef::{FunSignature, Type},
|
||||
toplevel::{
|
||||
numpy::{extract_ndims, unpack_ndarray_var_tys},
|
||||
DefinitionId,
|
||||
},
|
||||
typecheck::typedef::{FunSignature, Type},
|
||||
};
|
||||
|
||||
use super::{
|
||||
|
|
|
@ -68,7 +68,7 @@ fn cast_to_int_conversion<'ctx, 'a, G, HandleFloatFn>(
|
|||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
scalar: ScalarObject<'ctx>,
|
||||
target_int_dtype: Type,
|
||||
ret_int_dtype: Type,
|
||||
handle_float: HandleFloatFn,
|
||||
) -> ScalarObject<'ctx>
|
||||
where
|
||||
|
@ -76,7 +76,7 @@ where
|
|||
HandleFloatFn:
|
||||
FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, FloatValue<'ctx>) -> IntValue<'ctx>,
|
||||
{
|
||||
let target_int_dtype_llvm = ctx.get_llvm_type(generator, target_int_dtype).into_int_type();
|
||||
let ret_int_dtype_llvm = ctx.get_llvm_type(generator, ret_int_dtype).into_int_type();
|
||||
|
||||
let result = if ctx.unifier.unioned(scalar.dtype, ctx.primitives.float) {
|
||||
// Special handling for floats
|
||||
|
@ -85,20 +85,44 @@ where
|
|||
} else if ctx.unifier.unioned_any(scalar.dtype, int_like(ctx)) {
|
||||
let n = scalar.value.into_int_value();
|
||||
|
||||
if n.get_type().get_bit_width() <= target_int_dtype_llvm.get_bit_width() {
|
||||
ctx.builder.build_int_z_extend(n, target_int_dtype_llvm, "zext").unwrap()
|
||||
if n.get_type().get_bit_width() <= ret_int_dtype_llvm.get_bit_width() {
|
||||
ctx.builder.build_int_z_extend(n, ret_int_dtype_llvm, "zext").unwrap()
|
||||
} else {
|
||||
ctx.builder.build_int_truncate(n, target_int_dtype_llvm, "trunc").unwrap()
|
||||
ctx.builder.build_int_truncate(n, ret_int_dtype_llvm, "trunc").unwrap()
|
||||
}
|
||||
} else {
|
||||
unsupported_type(ctx, [scalar.dtype]);
|
||||
};
|
||||
|
||||
assert_eq!(target_int_dtype_llvm.get_bit_width(), result.get_type().get_bit_width()); // Sanity check
|
||||
ScalarObject { value: result.into(), dtype: target_int_dtype }
|
||||
assert_eq!(ret_int_dtype_llvm.get_bit_width(), result.get_type().get_bit_width()); // Sanity check
|
||||
ScalarObject { value: result.into(), dtype: ret_int_dtype }
|
||||
}
|
||||
|
||||
impl<'ctx> ScalarObject<'ctx> {
|
||||
/// Convenience function. Assume this scalar has typechecker type float64, get its underlying LLVM value.
|
||||
///
|
||||
/// Panic if the type is wrong.
|
||||
pub fn into_float64(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> FloatValue<'ctx> {
|
||||
if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
|
||||
self.value.into_float_value() // self.value must be a FloatValue
|
||||
} else {
|
||||
panic!("not a float type")
|
||||
}
|
||||
}
|
||||
|
||||
/// Convenience function. Assume this scalar has typechecker type int32, get its underlying LLVM value.
|
||||
///
|
||||
/// Panic if the type is wrong.
|
||||
pub fn into_int32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
|
||||
if ctx.unifier.unioned(self.dtype, ctx.primitives.int32) {
|
||||
let value = self.value.into_int_value();
|
||||
debug_assert_eq!(value.get_type().get_bit_width(), 32); // Sanity check
|
||||
value
|
||||
} else {
|
||||
panic!("not a float type")
|
||||
}
|
||||
}
|
||||
|
||||
/// Compare two scalars. Only int-to-int and float-to-float comparisons are allowed.
|
||||
/// Panic otherwise.
|
||||
pub fn compare<G: CodeGenerator + ?Sized>(
|
||||
|
@ -238,10 +262,7 @@ impl<'ctx> ScalarObject<'ctx> {
|
|||
|
||||
/// Invoke NAC3's builtin `bool()`.
|
||||
#[must_use]
|
||||
pub fn cast_to_bool<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
) -> Self {
|
||||
pub fn cast_to_bool(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
|
||||
// TODO: Why is the original code being so lax about i1 and i8 for the returned int type?
|
||||
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.bool) {
|
||||
self.value.into_int_value()
|
||||
|
@ -262,24 +283,47 @@ impl<'ctx> ScalarObject<'ctx> {
|
|||
ScalarObject { dtype: ctx.primitives.bool, value: result.as_basic_value_enum() }
|
||||
}
|
||||
|
||||
/// Invoke NAC3's builtin `float()`.
|
||||
#[must_use]
|
||||
pub fn cast_to_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
|
||||
let llvm_f64 = ctx.ctx.f64_type();
|
||||
|
||||
let result: FloatValue<'_> = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
|
||||
self.value.into_float_value()
|
||||
} else if ctx
|
||||
.unifier
|
||||
.unioned_any(self.dtype, [signed_ints(ctx).as_slice(), &[ctx.primitives.bool]].concat())
|
||||
{
|
||||
let n = self.value.into_int_value();
|
||||
ctx.builder.build_signed_int_to_float(n, llvm_f64, "sitofp").unwrap()
|
||||
} else if ctx.unifier.unioned_any(self.dtype, unsigned_ints(ctx)) {
|
||||
let n = self.value.into_int_value();
|
||||
ctx.builder.build_unsigned_int_to_float(n, llvm_f64, "uitofp").unwrap()
|
||||
} else {
|
||||
unsupported_type(ctx, [self.dtype]);
|
||||
};
|
||||
|
||||
ScalarObject { value: result.as_basic_value_enum(), dtype: ctx.primitives.float }
|
||||
}
|
||||
|
||||
/// Invoke NAC3's builtin `round()`.
|
||||
#[must_use]
|
||||
pub fn round<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
target_int_dtype: Type,
|
||||
ret_int_dtype: Type,
|
||||
) -> Self {
|
||||
let target_int_dtype_llvm = ctx.get_llvm_type(generator, target_int_dtype).into_int_type();
|
||||
let ret_int_dtype_llvm = ctx.get_llvm_type(generator, ret_int_dtype).into_int_type();
|
||||
|
||||
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
|
||||
let n = self.value.into_float_value();
|
||||
let n = llvm_intrinsics::call_float_round(ctx, n, None);
|
||||
ctx.builder.build_float_to_signed_int(n, target_int_dtype_llvm, "round").unwrap()
|
||||
ctx.builder.build_float_to_signed_int(n, ret_int_dtype_llvm, "round").unwrap()
|
||||
} else {
|
||||
unsupported_type(ctx, [self.dtype, target_int_dtype])
|
||||
unsupported_type(ctx, [self.dtype, ret_int_dtype])
|
||||
};
|
||||
ScalarObject { dtype: target_int_dtype, value: result.as_basic_value_enum() }
|
||||
ScalarObject { dtype: ret_int_dtype, value: result.as_basic_value_enum() }
|
||||
}
|
||||
|
||||
/// Invoke NAC3's builtin `np_round()`.
|
||||
|
@ -287,7 +331,7 @@ impl<'ctx> ScalarObject<'ctx> {
|
|||
/// NOTE: `np.round()` has different behaviors than `round()` in terms of their result
|
||||
/// on "tie" cases and return type.
|
||||
#[must_use]
|
||||
pub fn np_round<G: CodeGenerator + ?Sized>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
|
||||
pub fn np_round(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
|
||||
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
|
||||
let n = self.value.into_float_value();
|
||||
llvm_intrinsics::call_float_rint(ctx, n, None)
|
||||
|
@ -298,7 +342,7 @@ impl<'ctx> ScalarObject<'ctx> {
|
|||
}
|
||||
|
||||
/// Invoke NAC3's builtin `min()` or `max()`.
|
||||
fn min_or_max_helper(
|
||||
pub fn min_or_max(
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
kind: MinOrMax,
|
||||
a: Self,
|
||||
|
@ -335,15 +379,19 @@ impl<'ctx> ScalarObject<'ctx> {
|
|||
}
|
||||
|
||||
/// Invoke NAC3's builtin `floor()` or `ceil()`.
|
||||
///
|
||||
/// * `ret_int_dtype` - The type of int to return.
|
||||
///
|
||||
/// Takes in a float/int and returns an int of type `ret_int_dtype`
|
||||
#[must_use]
|
||||
pub fn floor_or_ceil<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
kind: FloorOrCeil,
|
||||
target_int_dtype: Type,
|
||||
ret_int_dtype: Type,
|
||||
) -> Self {
|
||||
let target_int_dtype_llvm = ctx.get_llvm_type(generator, target_int_dtype).into_int_type();
|
||||
let ret_int_dtype_llvm = ctx.get_llvm_type(generator, ret_int_dtype).into_int_type();
|
||||
|
||||
if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
|
||||
let function = match kind {
|
||||
|
@ -352,8 +400,45 @@ impl<'ctx> ScalarObject<'ctx> {
|
|||
};
|
||||
let n = self.value.into_float_value();
|
||||
let n = function(ctx, n, None);
|
||||
let n = ctx.builder.build_float_to_signed_int(n, target_int_dtype_llvm, "").unwrap();
|
||||
ScalarObject { dtype: target_int_dtype, value: n.as_basic_value_enum() }
|
||||
|
||||
let n = ctx.builder.build_float_to_signed_int(n, ret_int_dtype_llvm, "").unwrap();
|
||||
ScalarObject { dtype: ret_int_dtype, value: n.as_basic_value_enum() }
|
||||
} else {
|
||||
unsupported_type(ctx, [self.dtype])
|
||||
}
|
||||
}
|
||||
|
||||
/// Invoke NAC3's builtin `np_floor()`/ `np_ceil()`.
|
||||
///
|
||||
/// Takes in a float/int and returns a float64 result.
|
||||
#[must_use]
|
||||
pub fn np_floor_or_ceil(&self, ctx: &mut CodeGenContext<'ctx, '_>, kind: FloorOrCeil) -> Self {
|
||||
if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
|
||||
let function = match kind {
|
||||
FloorOrCeil::Floor => llvm_intrinsics::call_float_floor,
|
||||
FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil,
|
||||
};
|
||||
let n = self.value.into_float_value();
|
||||
let n = function(ctx, n, None);
|
||||
ScalarObject { dtype: ctx.primitives.float, value: n.as_basic_value_enum() }
|
||||
} else {
|
||||
unsupported_type(ctx, [self.dtype])
|
||||
}
|
||||
}
|
||||
|
||||
/// Invoke NAC3's builtin `abs()`.
|
||||
pub fn abs(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
|
||||
if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
|
||||
let n = self.value.into_float_value();
|
||||
let n = llvm_intrinsics::call_float_fabs(ctx, n, Some("abs"));
|
||||
ScalarObject { value: n.into(), dtype: ctx.primitives.float }
|
||||
} else if ctx.unifier.unioned_any(self.dtype, ints(ctx)) {
|
||||
let n = self.value.into_int_value();
|
||||
|
||||
let is_poisoned = ctx.ctx.bool_type().const_zero(); // is_poisoned = false
|
||||
let n = llvm_intrinsics::call_int_abs(ctx, n, is_poisoned, Some("abs"));
|
||||
|
||||
ScalarObject { value: n.into(), dtype: self.dtype }
|
||||
} else {
|
||||
unsupported_type(ctx, [self.dtype])
|
||||
}
|
||||
|
@ -361,12 +446,12 @@ impl<'ctx> ScalarObject<'ctx> {
|
|||
}
|
||||
|
||||
impl<'ctx> NDArrayObject<'ctx> {
|
||||
/// Helper function for NAC3's builtin `np_min()`, `np_max()`, `np_argmin()`, and `np_argmax()`.
|
||||
/// Helper function to implement NAC3's builtin `np_min()`, `np_max()`, `np_argmin()`, and `np_argmax()`.
|
||||
///
|
||||
/// Generate LLVM IR to find the extremum and index of the **first** extremum value.
|
||||
///
|
||||
/// Care has also been taken to make the error messages match that of NumPy.
|
||||
fn min_or_max_helper<G: CodeGenerator + ?Sized>(
|
||||
fn min_max_argmin_argmax_helper<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
|
@ -410,7 +495,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
|||
let old_extremum = ctx.builder.build_load(pextremum, "current_extremum").unwrap();
|
||||
let old_extremum = ScalarObject { dtype: self.dtype, value: old_extremum };
|
||||
|
||||
let new_extremum = ScalarObject::min_or_max_helper(ctx, kind, old_extremum, scalar);
|
||||
let new_extremum = ScalarObject::min_or_max(ctx, kind, old_extremum, scalar);
|
||||
|
||||
// Check if new_extremum is more extreme than old_extremum.
|
||||
let update_index = ScalarObject::compare(
|
||||
|
@ -455,7 +540,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
|||
MinOrMax::Max => "maximum",
|
||||
}
|
||||
);
|
||||
self.min_or_max_helper(generator, ctx, kind, &on_empty_err_msg).0
|
||||
self.min_max_argmin_argmax_helper(generator, ctx, kind, &on_empty_err_msg).0
|
||||
}
|
||||
|
||||
/// Invoke NAC3's builtin `np_argmin()` or `np_argmax()`.
|
||||
|
@ -472,6 +557,6 @@ impl<'ctx> NDArrayObject<'ctx> {
|
|||
MinOrMax::Max => "argmax",
|
||||
}
|
||||
);
|
||||
self.min_or_max_helper(generator, ctx, kind, &on_empty_err_msg).1
|
||||
self.min_max_argmin_argmax_helper(generator, ctx, kind, &on_empty_err_msg).1
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,12 +15,12 @@ use super::scalar::ScalarOrNDArray;
|
|||
|
||||
impl<'ctx> NDArrayObject<'ctx> {
|
||||
/// TODO: Document me. Has complex behavior.
|
||||
/// and explain why `ret_dtype` has to be specified beforehand.
|
||||
pub fn broadcasting_starmap<'a, G, MappingFn>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ndarrays: &[Self],
|
||||
ret_dtype: Type,
|
||||
name: &str,
|
||||
mapping: MappingFn,
|
||||
) -> Result<Self, String>
|
||||
where
|
||||
|
@ -30,7 +30,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
|||
&mut CodeGenContext<'ctx, 'a>,
|
||||
Int<'ctx, SizeT>,
|
||||
&[ScalarObject<'ctx>],
|
||||
) -> Result<ScalarObject<'ctx>, String>,
|
||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
{
|
||||
let sizet_model = IntModel(SizeT);
|
||||
|
||||
|
@ -43,7 +43,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
|||
ctx,
|
||||
ret_dtype,
|
||||
broadcast_result.ndims,
|
||||
name,
|
||||
"mapped_ndarray",
|
||||
);
|
||||
mapped_ndarray.copy_shape_from_array(generator, ctx, broadcast_result.shape);
|
||||
mapped_ndarray.create_data(generator, ctx);
|
||||
|
@ -59,7 +59,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
|||
let ret = mapping(generator, ctx, i, &elements)?;
|
||||
|
||||
let pret = mapped_ndarray.get_nth_pointer(generator, ctx, i, "pret");
|
||||
ctx.builder.build_store(pret, ret.value).unwrap();
|
||||
ctx.builder.build_store(pret, ret).unwrap();
|
||||
Ok(())
|
||||
})?;
|
||||
|
||||
|
@ -71,7 +71,6 @@ impl<'ctx> NDArrayObject<'ctx> {
|
|||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ret_dtype: Type,
|
||||
name: &str,
|
||||
mapping: Mapping,
|
||||
) -> Result<Self, String>
|
||||
where
|
||||
|
@ -88,23 +87,19 @@ impl<'ctx> NDArrayObject<'ctx> {
|
|||
ctx,
|
||||
&[*self],
|
||||
ret_dtype,
|
||||
name,
|
||||
|generator, ctx, i, scalars| {
|
||||
let value = mapping(generator, ctx, i, scalars[0])?;
|
||||
Ok(ScalarObject { dtype: ret_dtype, value })
|
||||
},
|
||||
|generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx> ScalarOrNDArray<'ctx> {
|
||||
/// TODO: Document me. Has complex behavior.
|
||||
/// and explain why `ret_dtype` has to be specified beforehand.
|
||||
pub fn broadcasting_starmap<'a, G, MappingFn>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
inputs: &[Self],
|
||||
ret_dtype: Type,
|
||||
name: &str,
|
||||
mapping: MappingFn,
|
||||
) -> Result<Self, String>
|
||||
where
|
||||
|
@ -114,7 +109,7 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
|
|||
&mut CodeGenContext<'ctx, 'a>,
|
||||
Int<'ctx, SizeT>,
|
||||
&[ScalarObject<'ctx>],
|
||||
) -> Result<ScalarObject<'ctx>, String>,
|
||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
{
|
||||
let sizet_model = IntModel(SizeT);
|
||||
|
||||
|
@ -124,15 +119,40 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
|
|||
|
||||
if let Some(scalars) = all_scalars {
|
||||
let i = sizet_model.const_0(generator, ctx.ctx); // Pass 0 as the index
|
||||
let scalar = mapping(generator, ctx, i, &scalars)?;
|
||||
let scalar =
|
||||
ScalarObject { value: mapping(generator, ctx, i, &scalars)?, dtype: ret_dtype };
|
||||
Ok(ScalarOrNDArray::Scalar(scalar))
|
||||
} else {
|
||||
// Promote all input to ndarrays and map through them.
|
||||
let inputs = inputs.iter().map(|input| input.as_ndarray(generator, ctx)).collect_vec();
|
||||
let ndarray = NDArrayObject::broadcasting_starmap(
|
||||
generator, ctx, &inputs, ret_dtype, name, mapping,
|
||||
)?;
|
||||
let ndarray =
|
||||
NDArrayObject::broadcasting_starmap(generator, ctx, &inputs, ret_dtype, mapping)?;
|
||||
Ok(ScalarOrNDArray::NDArray(ndarray))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn map<'a, G, Mapping>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ret_dtype: Type,
|
||||
mapping: Mapping,
|
||||
) -> Result<Self, String>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
Mapping: FnOnce(
|
||||
&mut G,
|
||||
&mut CodeGenContext<'ctx, 'a>,
|
||||
Int<'ctx, SizeT>,
|
||||
ScalarObject<'ctx>,
|
||||
) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
{
|
||||
ScalarOrNDArray::broadcasting_starmap(
|
||||
generator,
|
||||
ctx,
|
||||
&[*self],
|
||||
ret_dtype,
|
||||
|generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -229,6 +229,26 @@ impl<'ctx> NDArrayObject<'ctx> {
|
|||
Self::alloca_uninitialized(generator, ctx, dtype, ndims, name)
|
||||
}
|
||||
|
||||
/// Clone this ndaarray - Allocate a new ndarray with the same shape as this ndarray and copy the contents
|
||||
/// over.
|
||||
///
|
||||
/// The new ndarray will own its data and will be C-contiguous.
|
||||
pub fn make_clone<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
name: &str,
|
||||
) -> Self {
|
||||
let clone =
|
||||
NDArrayObject::alloca_uninitialized(generator, ctx, self.dtype, self.ndims, name);
|
||||
|
||||
let shape = self.value.gep(ctx, |f| f.shape).load(generator, ctx, "shape");
|
||||
clone.copy_shape_from_array(generator, ctx, shape);
|
||||
clone.create_data(generator, ctx);
|
||||
clone.copy_data_from(generator, ctx, *self);
|
||||
clone
|
||||
}
|
||||
|
||||
/// Get this ndarray's `ndims` as an LLVM constant.
|
||||
pub fn get_ndims<G: CodeGenerator + ?Sized>(
|
||||
&self,
|
||||
|
|
|
@ -13,22 +13,28 @@ use strum::IntoEnumIterator;
|
|||
|
||||
use crate::{
|
||||
codegen::{
|
||||
builtin_fns,
|
||||
builtin_fns::{self},
|
||||
classes::{ProxyValue, RangeValue},
|
||||
expr::destructure_range,
|
||||
irrt::*,
|
||||
extern_fns,
|
||||
irrt::{self, *},
|
||||
llvm_intrinsics,
|
||||
model::Int32,
|
||||
numpy::*,
|
||||
numpy_new::{self, gen_ndarray_transpose},
|
||||
stmt::exn_constructor,
|
||||
structure::ndarray::NDArrayObject,
|
||||
structure::ndarray::{
|
||||
functions::{FloorOrCeil, MinOrMax},
|
||||
scalar::{split_scalar_or_ndarray, ScalarObject, ScalarOrNDArray},
|
||||
NDArrayObject,
|
||||
},
|
||||
},
|
||||
symbol_resolver::SymbolValue,
|
||||
toplevel::{helper::PrimDef, numpy::make_ndarray_ty},
|
||||
typecheck::{
|
||||
numpy::create_ndims,
|
||||
typedef::{into_var_map, iter_type_vars, TypeVar, VarMap},
|
||||
toplevel::{
|
||||
helper::PrimDef,
|
||||
numpy::{create_ndims, make_ndarray_ty},
|
||||
},
|
||||
typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap},
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
@ -1053,16 +1059,34 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
let arg_ty = fun.0.args[0].ty;
|
||||
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||
|
||||
let func = match prim {
|
||||
PrimDef::FunInt32 => builtin_fns::call_int32,
|
||||
PrimDef::FunInt64 => builtin_fns::call_int64,
|
||||
PrimDef::FunUInt32 => builtin_fns::call_uint32,
|
||||
PrimDef::FunUInt64 => builtin_fns::call_uint64,
|
||||
PrimDef::FunFloat => builtin_fns::call_float,
|
||||
PrimDef::FunBool => builtin_fns::call_bool,
|
||||
let ret_dtype = match prim {
|
||||
PrimDef::FunInt32 => ctx.primitives.int32,
|
||||
PrimDef::FunInt64 => ctx.primitives.int64,
|
||||
PrimDef::FunUInt32 => ctx.primitives.uint32,
|
||||
PrimDef::FunUInt64 => ctx.primitives.uint64,
|
||||
PrimDef::FunFloat => ctx.primitives.float,
|
||||
PrimDef::FunBool => ctx.primitives.bool,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
Ok(Some(func(generator, ctx, (arg_ty, arg))?))
|
||||
|
||||
let result = split_scalar_or_ndarray(generator, ctx, arg, arg_ty).map(
|
||||
generator,
|
||||
ctx,
|
||||
ret_dtype,
|
||||
|generator, ctx, _i, scalar| {
|
||||
let result = match prim {
|
||||
PrimDef::FunInt32 => scalar.cast_to_int32(generator, ctx),
|
||||
PrimDef::FunInt64 => scalar.cast_to_int64(generator, ctx),
|
||||
PrimDef::FunUInt32 => scalar.cast_to_uint32(generator, ctx),
|
||||
PrimDef::FunUInt64 => scalar.cast_to_uint64(generator, ctx),
|
||||
PrimDef::FunFloat => scalar.cast_to_float(ctx),
|
||||
PrimDef::FunBool => scalar.cast_to_bool(ctx),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
Ok(result.value)
|
||||
},
|
||||
)?;
|
||||
Ok(Some(result.to_basic_value_enum()))
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
|
@ -1113,20 +1137,23 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
let arg_ty = fun.0.args[0].ty;
|
||||
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||
|
||||
let ret_elem_ty = size_variant.of_int(&ctx.primitives);
|
||||
Ok(Some(builtin_fns::call_round(generator, ctx, (arg_ty, arg), ret_elem_ty)?))
|
||||
let ret_int_dtype = size_variant.of_int(&ctx.primitives);
|
||||
|
||||
let result = split_scalar_or_ndarray(generator, ctx, arg, arg_ty).map(
|
||||
generator,
|
||||
ctx,
|
||||
ret_int_dtype,
|
||||
|generator, ctx, _i, scalar| {
|
||||
Ok(scalar.round(generator, ctx, ret_int_dtype).value)
|
||||
},
|
||||
)?;
|
||||
Ok(Some(result.to_basic_value_enum()))
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// Build the functions `ceil()` and `floor()` and their 64 bit variants.
|
||||
fn build_ceil_floor_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||
#[derive(Clone, Copy)]
|
||||
enum Kind {
|
||||
Floor,
|
||||
Ceil,
|
||||
}
|
||||
|
||||
debug_assert_prim_is_allowed(
|
||||
prim,
|
||||
&[PrimDef::FunFloor, PrimDef::FunFloor64, PrimDef::FunCeil, PrimDef::FunCeil64],
|
||||
|
@ -1134,10 +1161,10 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
|
||||
let (size_variant, kind) = {
|
||||
match prim {
|
||||
PrimDef::FunFloor => (SizeVariant::Bits32, Kind::Floor),
|
||||
PrimDef::FunFloor64 => (SizeVariant::Bits64, Kind::Floor),
|
||||
PrimDef::FunCeil => (SizeVariant::Bits32, Kind::Ceil),
|
||||
PrimDef::FunCeil64 => (SizeVariant::Bits64, Kind::Ceil),
|
||||
PrimDef::FunFloor => (SizeVariant::Bits32, FloorOrCeil::Floor),
|
||||
PrimDef::FunFloor64 => (SizeVariant::Bits64, FloorOrCeil::Floor),
|
||||
PrimDef::FunCeil => (SizeVariant::Bits32, FloorOrCeil::Ceil),
|
||||
PrimDef::FunCeil64 => (SizeVariant::Bits64, FloorOrCeil::Ceil),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
};
|
||||
|
@ -1177,12 +1204,15 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
let arg_ty = fun.0.args[0].ty;
|
||||
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||
|
||||
let ret_elem_ty = size_variant.of_int(&ctx.primitives);
|
||||
let func = match kind {
|
||||
Kind::Ceil => builtin_fns::call_ceil,
|
||||
Kind::Floor => builtin_fns::call_floor,
|
||||
};
|
||||
Ok(Some(func(generator, ctx, (arg_ty, arg), ret_elem_ty)?))
|
||||
let result = split_scalar_or_ndarray(generator, ctx, arg, arg_ty).map(
|
||||
generator,
|
||||
ctx,
|
||||
int_sized,
|
||||
|generator, ctx, _i, scalar| {
|
||||
Ok(scalar.floor_or_ceil(generator, ctx, kind, int_sized).value)
|
||||
},
|
||||
)?;
|
||||
Ok(Some(result.to_basic_value_enum()))
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
@ -1546,12 +1576,22 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
let arg_ty = fun.0.args[0].ty;
|
||||
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||
|
||||
let func = match prim {
|
||||
PrimDef::FunNpCeil => builtin_fns::call_ceil,
|
||||
PrimDef::FunNpFloor => builtin_fns::call_floor,
|
||||
let kind = match prim {
|
||||
PrimDef::FunNpFloor => FloorOrCeil::Floor,
|
||||
PrimDef::FunNpCeil => FloorOrCeil::Ceil,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
Ok(Some(func(generator, ctx, (arg_ty, arg), ctx.primitives.float)?))
|
||||
|
||||
let result = split_scalar_or_ndarray(generator, ctx, arg, arg_ty).map(
|
||||
generator,
|
||||
ctx,
|
||||
ctx.primitives.float,
|
||||
move |_generator, ctx, _i, scalar| {
|
||||
let result = scalar.np_floor_or_ceil(ctx, kind);
|
||||
Ok(result.value)
|
||||
},
|
||||
)?;
|
||||
Ok(Some(result.to_basic_value_enum()))
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
@ -1569,7 +1609,17 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
Box::new(|ctx, _, fun, args, generator| {
|
||||
let arg_ty = fun.0.args[0].ty;
|
||||
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||
Ok(Some(builtin_fns::call_numpy_round(generator, ctx, (arg_ty, arg))?))
|
||||
|
||||
let result = split_scalar_or_ndarray(generator, ctx, arg, arg_ty).map(
|
||||
generator,
|
||||
ctx,
|
||||
ctx.primitives.float,
|
||||
|_generator, ctx, _i, scalar| {
|
||||
let result = scalar.np_round(ctx);
|
||||
Ok(result.value)
|
||||
},
|
||||
)?;
|
||||
Ok(Some(result.to_basic_value_enum()))
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
@ -1678,16 +1728,21 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||
move |ctx, _, fun, args, generator| {
|
||||
let m_ty = fun.0.args[0].ty;
|
||||
let n_ty = fun.0.args[1].ty;
|
||||
let m_val = args[0].1.clone().to_basic_value_enum(ctx, generator, m_ty)?;
|
||||
|
||||
let n_ty = fun.0.args[1].ty;
|
||||
let n_val = args[1].1.clone().to_basic_value_enum(ctx, generator, n_ty)?;
|
||||
|
||||
let func = match prim {
|
||||
PrimDef::FunMin => builtin_fns::call_min,
|
||||
PrimDef::FunMax => builtin_fns::call_max,
|
||||
let kind = match prim {
|
||||
PrimDef::FunMin => MinOrMax::Min,
|
||||
PrimDef::FunMax => MinOrMax::Max,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
Ok(Some(func(ctx, (m_ty, m_val), (n_ty, n_val))))
|
||||
|
||||
let m = ScalarObject { dtype: m_ty, value: m_val };
|
||||
let n = ScalarObject { dtype: n_ty, value: n_val };
|
||||
let result = ScalarObject::min_or_max(ctx, kind, m, n);
|
||||
Ok(Some(result.value))
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
|
@ -1729,7 +1784,25 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
let a_ty = fun.0.args[0].ty;
|
||||
let a = args[0].1.clone().to_basic_value_enum(ctx, generator, a_ty)?;
|
||||
|
||||
Ok(Some(builtin_fns::call_numpy_max_min(generator, ctx, (a_ty, a), prim.name())?))
|
||||
let a = split_scalar_or_ndarray(generator, ctx, a, a_ty).as_ndarray(generator, ctx);
|
||||
let result = match prim {
|
||||
PrimDef::FunNpArgmin => a
|
||||
.argmin_or_argmax(generator, ctx, MinOrMax::Min)
|
||||
.value
|
||||
.as_basic_value_enum(),
|
||||
PrimDef::FunNpArgmax => a
|
||||
.argmin_or_argmax(generator, ctx, MinOrMax::Max)
|
||||
.value
|
||||
.as_basic_value_enum(),
|
||||
PrimDef::FunNpMin => {
|
||||
a.min_or_max(generator, ctx, MinOrMax::Min).value.as_basic_value_enum()
|
||||
}
|
||||
PrimDef::FunNpMax => {
|
||||
a.min_or_max(generator, ctx, MinOrMax::Max).value.as_basic_value_enum()
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
Ok(Some(result))
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
@ -1764,13 +1837,32 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
||||
|
||||
let func = match prim {
|
||||
PrimDef::FunNpMinimum => builtin_fns::call_numpy_minimum,
|
||||
PrimDef::FunNpMaximum => builtin_fns::call_numpy_maximum,
|
||||
let kind = match prim {
|
||||
PrimDef::FunNpMinimum => MinOrMax::Min,
|
||||
PrimDef::FunNpMaximum => MinOrMax::Max,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
Ok(Some(func(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
|
||||
let x1 = split_scalar_or_ndarray(generator, ctx, x1_val, x1_ty);
|
||||
let x2 = split_scalar_or_ndarray(generator, ctx, x2_val, x2_ty);
|
||||
|
||||
// NOTE: x1.dtype() and x2.dtype() should be the same
|
||||
let common_ty = x1.dtype();
|
||||
|
||||
let result = ScalarOrNDArray::broadcasting_starmap(
|
||||
generator,
|
||||
ctx,
|
||||
&[x1, x2],
|
||||
common_ty,
|
||||
|_generator, ctx, _i, scalars| {
|
||||
let x1 = scalars[0];
|
||||
let x2 = scalars[1];
|
||||
|
||||
let result = ScalarObject::min_or_max(ctx, kind, x1, x2);
|
||||
Ok(result.value)
|
||||
},
|
||||
)?;
|
||||
Ok(Some(result.to_basic_value_enum()))
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
|
@ -1781,6 +1873,7 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
fn build_abs_function(&mut self) -> TopLevelDef {
|
||||
let prim = PrimDef::FunAbs;
|
||||
|
||||
let num_ty = self.num_ty; // To move into codegen_callback
|
||||
TopLevelDef::Function {
|
||||
name: prim.name().into(),
|
||||
simple_name: prim.simple_name().into(),
|
||||
|
@ -1798,11 +1891,17 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
instance_to_stmt: HashMap::default(),
|
||||
resolver: None,
|
||||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||
|ctx, _, fun, args, generator| {
|
||||
move |ctx, _, fun, args, generator| {
|
||||
let n_ty = fun.0.args[0].ty;
|
||||
let n_val = args[0].1.clone().to_basic_value_enum(ctx, generator, n_ty)?;
|
||||
|
||||
Ok(Some(builtin_fns::call_abs(generator, ctx, (n_ty, n_val))?))
|
||||
let result = split_scalar_or_ndarray(generator, ctx, n_val, n_ty).map(
|
||||
generator,
|
||||
ctx,
|
||||
num_ty.ty,
|
||||
|_generator, ctx, _i, scalar| Ok(scalar.abs(ctx).value),
|
||||
)?;
|
||||
Ok(Some(result.to_basic_value_enum()))
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
|
@ -1825,13 +1924,23 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
let x_ty = fun.0.args[0].ty;
|
||||
let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?;
|
||||
|
||||
let func = match prim {
|
||||
PrimDef::FunNpIsInf => builtin_fns::call_numpy_isinf,
|
||||
PrimDef::FunNpIsNan => builtin_fns::call_numpy_isnan,
|
||||
let function = match prim {
|
||||
PrimDef::FunNpIsInf => irrt::call_isnan,
|
||||
PrimDef::FunNpIsNan => irrt::call_isinf,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
Ok(Some(func(generator, ctx, (x_ty, x_val))?))
|
||||
let result = split_scalar_or_ndarray(generator, ctx, x_val, x_ty).map(
|
||||
generator,
|
||||
ctx,
|
||||
ctx.primitives.bool,
|
||||
|generator, ctx, _i, scalar| {
|
||||
let n = scalar.into_float64(ctx);
|
||||
let n = function(generator, ctx, n);
|
||||
Ok(n.as_basic_value_enum())
|
||||
},
|
||||
)?;
|
||||
Ok(Some(result.to_basic_value_enum()))
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
@ -1889,49 +1998,58 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
let arg_ty = fun.0.args[0].ty;
|
||||
let arg_val = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||
|
||||
let func = match prim {
|
||||
PrimDef::FunNpSin => builtin_fns::call_numpy_sin,
|
||||
PrimDef::FunNpCos => builtin_fns::call_numpy_cos,
|
||||
PrimDef::FunNpTan => builtin_fns::call_numpy_tan,
|
||||
let result = split_scalar_or_ndarray(generator, ctx, arg_val, arg_ty).map(
|
||||
generator,
|
||||
ctx,
|
||||
ctx.primitives.float,
|
||||
|_generator, ctx, _i, scalar| {
|
||||
let n = scalar.into_float64(ctx);
|
||||
let n = match prim {
|
||||
PrimDef::FunNpSin => llvm_intrinsics::call_float_sin(ctx, n, None),
|
||||
PrimDef::FunNpCos => llvm_intrinsics::call_float_cos(ctx, n, None),
|
||||
PrimDef::FunNpTan => extern_fns::call_tan(ctx, n, None),
|
||||
|
||||
PrimDef::FunNpArcsin => builtin_fns::call_numpy_arcsin,
|
||||
PrimDef::FunNpArccos => builtin_fns::call_numpy_arccos,
|
||||
PrimDef::FunNpArctan => builtin_fns::call_numpy_arctan,
|
||||
PrimDef::FunNpArcsin => extern_fns::call_asin(ctx, n, None),
|
||||
PrimDef::FunNpArccos => extern_fns::call_acos(ctx, n, None),
|
||||
PrimDef::FunNpArctan => extern_fns::call_atan(ctx, n, None),
|
||||
|
||||
PrimDef::FunNpSinh => builtin_fns::call_numpy_sinh,
|
||||
PrimDef::FunNpCosh => builtin_fns::call_numpy_cosh,
|
||||
PrimDef::FunNpTanh => builtin_fns::call_numpy_tanh,
|
||||
PrimDef::FunNpSinh => extern_fns::call_sinh(ctx, n, None),
|
||||
PrimDef::FunNpCosh => extern_fns::call_cosh(ctx, n, None),
|
||||
PrimDef::FunNpTanh => extern_fns::call_tanh(ctx, n, None),
|
||||
|
||||
PrimDef::FunNpArcsinh => builtin_fns::call_numpy_arcsinh,
|
||||
PrimDef::FunNpArccosh => builtin_fns::call_numpy_arccosh,
|
||||
PrimDef::FunNpArctanh => builtin_fns::call_numpy_arctanh,
|
||||
PrimDef::FunNpArcsinh => extern_fns::call_asinh(ctx, n, None),
|
||||
PrimDef::FunNpArccosh => extern_fns::call_acosh(ctx, n, None),
|
||||
PrimDef::FunNpArctanh => extern_fns::call_atanh(ctx, n, None),
|
||||
|
||||
PrimDef::FunNpExp => builtin_fns::call_numpy_exp,
|
||||
PrimDef::FunNpExp2 => builtin_fns::call_numpy_exp2,
|
||||
PrimDef::FunNpExpm1 => builtin_fns::call_numpy_expm1,
|
||||
PrimDef::FunNpExp => llvm_intrinsics::call_float_exp(ctx, n, None),
|
||||
PrimDef::FunNpExp2 => llvm_intrinsics::call_float_exp2(ctx, n, None),
|
||||
PrimDef::FunNpExpm1 => extern_fns::call_expm1(ctx, n, None),
|
||||
|
||||
PrimDef::FunNpLog => builtin_fns::call_numpy_log,
|
||||
PrimDef::FunNpLog2 => builtin_fns::call_numpy_log2,
|
||||
PrimDef::FunNpLog10 => builtin_fns::call_numpy_log10,
|
||||
PrimDef::FunNpLog => llvm_intrinsics::call_float_log(ctx, n, None),
|
||||
PrimDef::FunNpLog2 => llvm_intrinsics::call_float_log2(ctx, n, None),
|
||||
PrimDef::FunNpLog10 => llvm_intrinsics::call_float_log10(ctx, n, None),
|
||||
|
||||
PrimDef::FunNpSqrt => builtin_fns::call_numpy_sqrt,
|
||||
PrimDef::FunNpCbrt => builtin_fns::call_numpy_cbrt,
|
||||
PrimDef::FunNpSqrt => llvm_intrinsics::call_float_sqrt(ctx, n, None),
|
||||
PrimDef::FunNpCbrt => extern_fns::call_cbrt(ctx, n, None),
|
||||
|
||||
PrimDef::FunNpFabs => builtin_fns::call_numpy_fabs,
|
||||
PrimDef::FunNpRint => builtin_fns::call_numpy_rint,
|
||||
PrimDef::FunNpFabs => llvm_intrinsics::call_float_fabs(ctx, n, None),
|
||||
PrimDef::FunNpRint => llvm_intrinsics::call_float_rint(ctx, n, None),
|
||||
|
||||
PrimDef::FunSpSpecErf => builtin_fns::call_scipy_special_erf,
|
||||
PrimDef::FunSpSpecErfc => builtin_fns::call_scipy_special_erfc,
|
||||
PrimDef::FunSpSpecErf => extern_fns::call_erf(ctx, n, None),
|
||||
PrimDef::FunSpSpecErfc => extern_fns::call_erfc(ctx, n, None),
|
||||
|
||||
PrimDef::FunSpSpecGamma => builtin_fns::call_scipy_special_gamma,
|
||||
PrimDef::FunSpSpecGammaln => builtin_fns::call_scipy_special_gammaln,
|
||||
PrimDef::FunSpSpecGamma => irrt::call_gamma(ctx, n),
|
||||
PrimDef::FunSpSpecGammaln => irrt::call_gammaln(ctx, n),
|
||||
|
||||
PrimDef::FunSpSpecJ0 => builtin_fns::call_scipy_special_j0,
|
||||
PrimDef::FunSpSpecJ1 => builtin_fns::call_scipy_special_j1,
|
||||
PrimDef::FunSpSpecJ0 => irrt::call_j0(ctx, n),
|
||||
PrimDef::FunSpSpecJ1 => extern_fns::call_j1(ctx, n, None),
|
||||
|
||||
_ => unreachable!(),
|
||||
};
|
||||
Ok(Some(func(generator, ctx, (arg_ty, arg_val))?))
|
||||
_ => unreachable!(),
|
||||
};
|
||||
Ok(n.as_basic_value_enum())
|
||||
},
|
||||
)?;
|
||||
Ok(Some(result.to_basic_value_enum()))
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
@ -1953,20 +2071,20 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
|
||||
let PrimitiveStore { float, int32, .. } = *self.primitives;
|
||||
|
||||
// The argument types of the two input arguments are controlled here.
|
||||
let (x1_ty, x2_ty) = match prim {
|
||||
// The argument types of the two input arguments + the return type
|
||||
let (x1_dtype, x2_dtype, ret_dtype) = match prim {
|
||||
PrimDef::FunNpArctan2
|
||||
| PrimDef::FunNpCopysign
|
||||
| PrimDef::FunNpFmax
|
||||
| PrimDef::FunNpFmin
|
||||
| PrimDef::FunNpHypot
|
||||
| PrimDef::FunNpNextAfter => (float, float),
|
||||
PrimDef::FunNpLdExp => (float, int32),
|
||||
| PrimDef::FunNpNextAfter => (float, float, float),
|
||||
PrimDef::FunNpLdExp => (float, int32, float),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let x1_ty = self.new_type_or_ndarray_ty(x1_ty);
|
||||
let x2_ty = self.new_type_or_ndarray_ty(x2_ty);
|
||||
let x1_ty = self.new_type_or_ndarray_ty(x1_dtype);
|
||||
let x2_ty = self.new_type_or_ndarray_ty(x2_dtype);
|
||||
|
||||
let param_ty = &[(x1_ty.ty, "x1"), (x2_ty.ty, "x2")];
|
||||
let ret_ty = self.unifier.get_fresh_var(None, None);
|
||||
|
@ -1990,21 +2108,46 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
move |ctx, _, fun, args, generator| {
|
||||
let x1_ty = fun.0.args[0].ty;
|
||||
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||
|
||||
let x2_ty = fun.0.args[1].ty;
|
||||
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
||||
|
||||
let func = match prim {
|
||||
PrimDef::FunNpArctan2 => builtin_fns::call_numpy_arctan2,
|
||||
PrimDef::FunNpCopysign => builtin_fns::call_numpy_copysign,
|
||||
PrimDef::FunNpFmax => builtin_fns::call_numpy_fmax,
|
||||
PrimDef::FunNpFmin => builtin_fns::call_numpy_fmin,
|
||||
PrimDef::FunNpLdExp => builtin_fns::call_numpy_ldexp,
|
||||
PrimDef::FunNpHypot => builtin_fns::call_numpy_hypot,
|
||||
PrimDef::FunNpNextAfter => builtin_fns::call_numpy_nextafter,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let x1 = split_scalar_or_ndarray(generator, ctx, x1_val, x1_ty);
|
||||
let x2 = split_scalar_or_ndarray(generator, ctx, x2_val, x2_ty);
|
||||
|
||||
Ok(Some(func(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
|
||||
let result = ScalarOrNDArray::broadcasting_starmap(
|
||||
generator,
|
||||
ctx,
|
||||
&[x1, x2],
|
||||
ret_dtype,
|
||||
|_generator, ctx, _i, scalars| {
|
||||
let x1 = scalars[0];
|
||||
let x2 = scalars[1];
|
||||
|
||||
let result = match prim {
|
||||
PrimDef::FunNpArctan2
|
||||
| PrimDef::FunNpCopysign
|
||||
| PrimDef::FunNpFmax
|
||||
| PrimDef::FunNpFmin
|
||||
| PrimDef::FunNpHypot
|
||||
| PrimDef::FunNpNextAfter => {
|
||||
let x1 = x1.into_float64(ctx);
|
||||
let x2 = x2.into_float64(ctx);
|
||||
extern_fns::call_atan2(ctx, x1, x2, None).as_basic_value_enum()
|
||||
}
|
||||
PrimDef::FunNpLdExp => {
|
||||
let x1 = x1.into_float64(ctx);
|
||||
let x2 = x2.into_int32(ctx);
|
||||
extern_fns::call_ldexp(ctx, x1, x2, None).as_basic_value_enum()
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
Ok(result)
|
||||
},
|
||||
)?;
|
||||
|
||||
Ok(Some(result.to_basic_value_enum()))
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
symbol_resolver::SymbolValue,
|
||||
toplevel::helper::PrimDef,
|
||||
typecheck::{
|
||||
type_inferencer::PrimitiveStore,
|
||||
|
@ -83,3 +86,33 @@ pub fn unpack_ndarray_var_ids(unifier: &mut Unifier, ndarray: Type) -> (TypeVarI
|
|||
pub fn unpack_ndarray_var_tys(unifier: &mut Unifier, ndarray: Type) -> (Type, Type) {
|
||||
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.1).collect_tuple().unwrap()
|
||||
}
|
||||
|
||||
/// Extract an ndarray's `ndims` [type][`Type`] in `u64`. Panic if not possible.
|
||||
/// The `ndims` must only contain 1 value.
|
||||
#[must_use]
|
||||
pub fn extract_ndims(unifier: &Unifier, ndims_ty: Type) -> u64 {
|
||||
let ndims_ty_enum = unifier.get_ty_immutable(ndims_ty);
|
||||
let TypeEnum::TLiteral { values, .. } = &*ndims_ty_enum else {
|
||||
panic!("ndims_ty should be a TLiteral");
|
||||
};
|
||||
|
||||
assert_eq!(values.len(), 1, "ndims_ty TLiteral should only contain 1 value");
|
||||
|
||||
let ndims = values[0].clone();
|
||||
u64::try_from(ndims).unwrap()
|
||||
}
|
||||
|
||||
/// Return an ndarray's `ndims` as a typechecker [`Type`] from its `u64` value.
|
||||
pub fn create_ndims(unifier: &mut Unifier, ndims: u64) -> Type {
|
||||
unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None)
|
||||
}
|
||||
|
||||
/// Return the ndims after broadcasting ndarrays of different ndims.
|
||||
///
|
||||
/// Panics if the input list is empty.
|
||||
pub fn get_broadcast_all_ndims<I>(ndims: I) -> u64
|
||||
where
|
||||
I: IntoIterator<Item = u64>,
|
||||
{
|
||||
ndims.into_iter().max().unwrap()
|
||||
}
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
mod function_check;
|
||||
pub mod magic_methods;
|
||||
pub mod numpy;
|
||||
pub mod type_error;
|
||||
pub mod type_inferencer;
|
||||
pub mod typedef;
|
||||
|
|
|
@ -1,33 +0,0 @@
|
|||
use crate::{symbol_resolver::SymbolValue, typecheck::typedef::TypeEnum};
|
||||
|
||||
use super::typedef::{Type, Unifier};
|
||||
|
||||
/// Extract an ndarray's `ndims` [type][`Type`] in `u64`. Panic if not possible.
|
||||
/// The `ndims` must only contain 1 value.
|
||||
#[must_use]
|
||||
pub fn extract_ndims(unifier: &Unifier, ndims_ty: Type) -> u64 {
|
||||
let ndims_ty_enum = unifier.get_ty_immutable(ndims_ty);
|
||||
let TypeEnum::TLiteral { values, .. } = &*ndims_ty_enum else {
|
||||
panic!("ndims_ty should be a TLiteral");
|
||||
};
|
||||
|
||||
assert_eq!(values.len(), 1, "ndims_ty TLiteral should only contain 1 value");
|
||||
|
||||
let ndims = values[0].clone();
|
||||
u64::try_from(ndims).unwrap()
|
||||
}
|
||||
|
||||
/// Return an ndarray's `ndims` as a typechecker [`Type`] from its `u64` value.
|
||||
pub fn create_ndims(unifier: &mut Unifier, ndims: u64) -> Type {
|
||||
unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None)
|
||||
}
|
||||
|
||||
/// Return the ndims after broadcasting ndarrays of different ndims.
|
||||
///
|
||||
/// Panics if the input list is empty.
|
||||
pub fn get_broadcast_all_ndims<I>(ndims: I) -> u64
|
||||
where
|
||||
I: IntoIterator<Item = u64>,
|
||||
{
|
||||
ndims.into_iter().max().unwrap()
|
||||
}
|
|
@ -11,12 +11,11 @@ use super::{
|
|||
RecordField, RecordKey, Type, TypeEnum, TypeVar, Unifier, VarMap,
|
||||
},
|
||||
};
|
||||
use crate::typecheck::numpy::extract_ndims;
|
||||
use crate::{
|
||||
symbol_resolver::{SymbolResolver, SymbolValue},
|
||||
toplevel::{
|
||||
helper::{arraylike_flatten_element_type, arraylike_get_ndims, PrimDef},
|
||||
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
||||
numpy::{extract_ndims, make_ndarray_ty, unpack_ndarray_var_tys},
|
||||
TopLevelContext, TopLevelDef,
|
||||
},
|
||||
typecheck::typedef::Mapping,
|
||||
|
|
Loading…
Reference in New Issue