forked from M-Labs/nac3
1
0
Fork 0

core/ndstrides: checkpoint 3

This commit is contained in:
lyken 2024-08-09 12:03:10 +08:00
parent 937b36dcfd
commit 858b4b9f3f
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
10 changed files with 453 additions and 1483 deletions

File diff suppressed because it is too large Load Diff

View File

@ -15,11 +15,11 @@ use crate::{
},
},
symbol_resolver::ValueEnum,
toplevel::{numpy::unpack_ndarray_var_tys, DefinitionId},
typecheck::{
numpy::extract_ndims,
typedef::{FunSignature, Type},
toplevel::{
numpy::{extract_ndims, unpack_ndarray_var_tys},
DefinitionId,
},
typecheck::typedef::{FunSignature, Type},
};
use super::{

View File

@ -68,7 +68,7 @@ fn cast_to_int_conversion<'ctx, 'a, G, HandleFloatFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
scalar: ScalarObject<'ctx>,
target_int_dtype: Type,
ret_int_dtype: Type,
handle_float: HandleFloatFn,
) -> ScalarObject<'ctx>
where
@ -76,7 +76,7 @@ where
HandleFloatFn:
FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, FloatValue<'ctx>) -> IntValue<'ctx>,
{
let target_int_dtype_llvm = ctx.get_llvm_type(generator, target_int_dtype).into_int_type();
let ret_int_dtype_llvm = ctx.get_llvm_type(generator, ret_int_dtype).into_int_type();
let result = if ctx.unifier.unioned(scalar.dtype, ctx.primitives.float) {
// Special handling for floats
@ -85,20 +85,44 @@ where
} else if ctx.unifier.unioned_any(scalar.dtype, int_like(ctx)) {
let n = scalar.value.into_int_value();
if n.get_type().get_bit_width() <= target_int_dtype_llvm.get_bit_width() {
ctx.builder.build_int_z_extend(n, target_int_dtype_llvm, "zext").unwrap()
if n.get_type().get_bit_width() <= ret_int_dtype_llvm.get_bit_width() {
ctx.builder.build_int_z_extend(n, ret_int_dtype_llvm, "zext").unwrap()
} else {
ctx.builder.build_int_truncate(n, target_int_dtype_llvm, "trunc").unwrap()
ctx.builder.build_int_truncate(n, ret_int_dtype_llvm, "trunc").unwrap()
}
} else {
unsupported_type(ctx, [scalar.dtype]);
};
assert_eq!(target_int_dtype_llvm.get_bit_width(), result.get_type().get_bit_width()); // Sanity check
ScalarObject { value: result.into(), dtype: target_int_dtype }
assert_eq!(ret_int_dtype_llvm.get_bit_width(), result.get_type().get_bit_width()); // Sanity check
ScalarObject { value: result.into(), dtype: ret_int_dtype }
}
impl<'ctx> ScalarObject<'ctx> {
/// Convenience function. Assume this scalar has typechecker type float64, get its underlying LLVM value.
///
/// Panic if the type is wrong.
pub fn into_float64(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> FloatValue<'ctx> {
if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
self.value.into_float_value() // self.value must be a FloatValue
} else {
panic!("not a float type")
}
}
/// Convenience function. Assume this scalar has typechecker type int32, get its underlying LLVM value.
///
/// Panic if the type is wrong.
pub fn into_int32(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
if ctx.unifier.unioned(self.dtype, ctx.primitives.int32) {
let value = self.value.into_int_value();
debug_assert_eq!(value.get_type().get_bit_width(), 32); // Sanity check
value
} else {
panic!("not a float type")
}
}
/// Compare two scalars. Only int-to-int and float-to-float comparisons are allowed.
/// Panic otherwise.
pub fn compare<G: CodeGenerator + ?Sized>(
@ -238,10 +262,7 @@ impl<'ctx> ScalarObject<'ctx> {
/// Invoke NAC3's builtin `bool()`.
#[must_use]
pub fn cast_to_bool<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
) -> Self {
pub fn cast_to_bool(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
// TODO: Why is the original code being so lax about i1 and i8 for the returned int type?
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.bool) {
self.value.into_int_value()
@ -262,24 +283,47 @@ impl<'ctx> ScalarObject<'ctx> {
ScalarObject { dtype: ctx.primitives.bool, value: result.as_basic_value_enum() }
}
/// Invoke NAC3's builtin `float()`.
#[must_use]
pub fn cast_to_float(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
let llvm_f64 = ctx.ctx.f64_type();
let result: FloatValue<'_> = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
self.value.into_float_value()
} else if ctx
.unifier
.unioned_any(self.dtype, [signed_ints(ctx).as_slice(), &[ctx.primitives.bool]].concat())
{
let n = self.value.into_int_value();
ctx.builder.build_signed_int_to_float(n, llvm_f64, "sitofp").unwrap()
} else if ctx.unifier.unioned_any(self.dtype, unsigned_ints(ctx)) {
let n = self.value.into_int_value();
ctx.builder.build_unsigned_int_to_float(n, llvm_f64, "uitofp").unwrap()
} else {
unsupported_type(ctx, [self.dtype]);
};
ScalarObject { value: result.as_basic_value_enum(), dtype: ctx.primitives.float }
}
/// Invoke NAC3's builtin `round()`.
#[must_use]
pub fn round<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
target_int_dtype: Type,
ret_int_dtype: Type,
) -> Self {
let target_int_dtype_llvm = ctx.get_llvm_type(generator, target_int_dtype).into_int_type();
let ret_int_dtype_llvm = ctx.get_llvm_type(generator, ret_int_dtype).into_int_type();
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
let n = self.value.into_float_value();
let n = llvm_intrinsics::call_float_round(ctx, n, None);
ctx.builder.build_float_to_signed_int(n, target_int_dtype_llvm, "round").unwrap()
ctx.builder.build_float_to_signed_int(n, ret_int_dtype_llvm, "round").unwrap()
} else {
unsupported_type(ctx, [self.dtype, target_int_dtype])
unsupported_type(ctx, [self.dtype, ret_int_dtype])
};
ScalarObject { dtype: target_int_dtype, value: result.as_basic_value_enum() }
ScalarObject { dtype: ret_int_dtype, value: result.as_basic_value_enum() }
}
/// Invoke NAC3's builtin `np_round()`.
@ -287,7 +331,7 @@ impl<'ctx> ScalarObject<'ctx> {
/// NOTE: `np.round()` has different behaviors than `round()` in terms of their result
/// on "tie" cases and return type.
#[must_use]
pub fn np_round<G: CodeGenerator + ?Sized>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
pub fn np_round(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
let result = if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
let n = self.value.into_float_value();
llvm_intrinsics::call_float_rint(ctx, n, None)
@ -298,7 +342,7 @@ impl<'ctx> ScalarObject<'ctx> {
}
/// Invoke NAC3's builtin `min()` or `max()`.
fn min_or_max_helper(
pub fn min_or_max(
ctx: &mut CodeGenContext<'ctx, '_>,
kind: MinOrMax,
a: Self,
@ -335,15 +379,19 @@ impl<'ctx> ScalarObject<'ctx> {
}
/// Invoke NAC3's builtin `floor()` or `ceil()`.
///
/// * `ret_int_dtype` - The type of int to return.
///
/// Takes in a float/int and returns an int of type `ret_int_dtype`
#[must_use]
pub fn floor_or_ceil<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
kind: FloorOrCeil,
target_int_dtype: Type,
ret_int_dtype: Type,
) -> Self {
let target_int_dtype_llvm = ctx.get_llvm_type(generator, target_int_dtype).into_int_type();
let ret_int_dtype_llvm = ctx.get_llvm_type(generator, ret_int_dtype).into_int_type();
if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
let function = match kind {
@ -352,8 +400,45 @@ impl<'ctx> ScalarObject<'ctx> {
};
let n = self.value.into_float_value();
let n = function(ctx, n, None);
let n = ctx.builder.build_float_to_signed_int(n, target_int_dtype_llvm, "").unwrap();
ScalarObject { dtype: target_int_dtype, value: n.as_basic_value_enum() }
let n = ctx.builder.build_float_to_signed_int(n, ret_int_dtype_llvm, "").unwrap();
ScalarObject { dtype: ret_int_dtype, value: n.as_basic_value_enum() }
} else {
unsupported_type(ctx, [self.dtype])
}
}
/// Invoke NAC3's builtin `np_floor()`/ `np_ceil()`.
///
/// Takes in a float/int and returns a float64 result.
#[must_use]
pub fn np_floor_or_ceil(&self, ctx: &mut CodeGenContext<'ctx, '_>, kind: FloorOrCeil) -> Self {
if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
let function = match kind {
FloorOrCeil::Floor => llvm_intrinsics::call_float_floor,
FloorOrCeil::Ceil => llvm_intrinsics::call_float_ceil,
};
let n = self.value.into_float_value();
let n = function(ctx, n, None);
ScalarObject { dtype: ctx.primitives.float, value: n.as_basic_value_enum() }
} else {
unsupported_type(ctx, [self.dtype])
}
}
/// Invoke NAC3's builtin `abs()`.
pub fn abs(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Self {
if ctx.unifier.unioned(self.dtype, ctx.primitives.float) {
let n = self.value.into_float_value();
let n = llvm_intrinsics::call_float_fabs(ctx, n, Some("abs"));
ScalarObject { value: n.into(), dtype: ctx.primitives.float }
} else if ctx.unifier.unioned_any(self.dtype, ints(ctx)) {
let n = self.value.into_int_value();
let is_poisoned = ctx.ctx.bool_type().const_zero(); // is_poisoned = false
let n = llvm_intrinsics::call_int_abs(ctx, n, is_poisoned, Some("abs"));
ScalarObject { value: n.into(), dtype: self.dtype }
} else {
unsupported_type(ctx, [self.dtype])
}
@ -361,12 +446,12 @@ impl<'ctx> ScalarObject<'ctx> {
}
impl<'ctx> NDArrayObject<'ctx> {
/// Helper function for NAC3's builtin `np_min()`, `np_max()`, `np_argmin()`, and `np_argmax()`.
/// Helper function to implement NAC3's builtin `np_min()`, `np_max()`, `np_argmin()`, and `np_argmax()`.
///
/// Generate LLVM IR to find the extremum and index of the **first** extremum value.
///
/// Care has also been taken to make the error messages match that of NumPy.
fn min_or_max_helper<G: CodeGenerator + ?Sized>(
fn min_max_argmin_argmax_helper<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
@ -410,7 +495,7 @@ impl<'ctx> NDArrayObject<'ctx> {
let old_extremum = ctx.builder.build_load(pextremum, "current_extremum").unwrap();
let old_extremum = ScalarObject { dtype: self.dtype, value: old_extremum };
let new_extremum = ScalarObject::min_or_max_helper(ctx, kind, old_extremum, scalar);
let new_extremum = ScalarObject::min_or_max(ctx, kind, old_extremum, scalar);
// Check if new_extremum is more extreme than old_extremum.
let update_index = ScalarObject::compare(
@ -455,7 +540,7 @@ impl<'ctx> NDArrayObject<'ctx> {
MinOrMax::Max => "maximum",
}
);
self.min_or_max_helper(generator, ctx, kind, &on_empty_err_msg).0
self.min_max_argmin_argmax_helper(generator, ctx, kind, &on_empty_err_msg).0
}
/// Invoke NAC3's builtin `np_argmin()` or `np_argmax()`.
@ -472,6 +557,6 @@ impl<'ctx> NDArrayObject<'ctx> {
MinOrMax::Max => "argmax",
}
);
self.min_or_max_helper(generator, ctx, kind, &on_empty_err_msg).1
self.min_max_argmin_argmax_helper(generator, ctx, kind, &on_empty_err_msg).1
}
}

View File

@ -15,12 +15,12 @@ use super::scalar::ScalarOrNDArray;
impl<'ctx> NDArrayObject<'ctx> {
/// TODO: Document me. Has complex behavior.
/// and explain why `ret_dtype` has to be specified beforehand.
pub fn broadcasting_starmap<'a, G, MappingFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
ndarrays: &[Self],
ret_dtype: Type,
name: &str,
mapping: MappingFn,
) -> Result<Self, String>
where
@ -30,7 +30,7 @@ impl<'ctx> NDArrayObject<'ctx> {
&mut CodeGenContext<'ctx, 'a>,
Int<'ctx, SizeT>,
&[ScalarObject<'ctx>],
) -> Result<ScalarObject<'ctx>, String>,
) -> Result<BasicValueEnum<'ctx>, String>,
{
let sizet_model = IntModel(SizeT);
@ -43,7 +43,7 @@ impl<'ctx> NDArrayObject<'ctx> {
ctx,
ret_dtype,
broadcast_result.ndims,
name,
"mapped_ndarray",
);
mapped_ndarray.copy_shape_from_array(generator, ctx, broadcast_result.shape);
mapped_ndarray.create_data(generator, ctx);
@ -59,7 +59,7 @@ impl<'ctx> NDArrayObject<'ctx> {
let ret = mapping(generator, ctx, i, &elements)?;
let pret = mapped_ndarray.get_nth_pointer(generator, ctx, i, "pret");
ctx.builder.build_store(pret, ret.value).unwrap();
ctx.builder.build_store(pret, ret).unwrap();
Ok(())
})?;
@ -71,7 +71,6 @@ impl<'ctx> NDArrayObject<'ctx> {
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
ret_dtype: Type,
name: &str,
mapping: Mapping,
) -> Result<Self, String>
where
@ -88,23 +87,19 @@ impl<'ctx> NDArrayObject<'ctx> {
ctx,
&[*self],
ret_dtype,
name,
|generator, ctx, i, scalars| {
let value = mapping(generator, ctx, i, scalars[0])?;
Ok(ScalarObject { dtype: ret_dtype, value })
},
|generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]),
)
}
}
impl<'ctx> ScalarOrNDArray<'ctx> {
/// TODO: Document me. Has complex behavior.
/// and explain why `ret_dtype` has to be specified beforehand.
pub fn broadcasting_starmap<'a, G, MappingFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
inputs: &[Self],
ret_dtype: Type,
name: &str,
mapping: MappingFn,
) -> Result<Self, String>
where
@ -114,7 +109,7 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
&mut CodeGenContext<'ctx, 'a>,
Int<'ctx, SizeT>,
&[ScalarObject<'ctx>],
) -> Result<ScalarObject<'ctx>, String>,
) -> Result<BasicValueEnum<'ctx>, String>,
{
let sizet_model = IntModel(SizeT);
@ -124,15 +119,40 @@ impl<'ctx> ScalarOrNDArray<'ctx> {
if let Some(scalars) = all_scalars {
let i = sizet_model.const_0(generator, ctx.ctx); // Pass 0 as the index
let scalar = mapping(generator, ctx, i, &scalars)?;
let scalar =
ScalarObject { value: mapping(generator, ctx, i, &scalars)?, dtype: ret_dtype };
Ok(ScalarOrNDArray::Scalar(scalar))
} else {
// Promote all input to ndarrays and map through them.
let inputs = inputs.iter().map(|input| input.as_ndarray(generator, ctx)).collect_vec();
let ndarray = NDArrayObject::broadcasting_starmap(
generator, ctx, &inputs, ret_dtype, name, mapping,
)?;
let ndarray =
NDArrayObject::broadcasting_starmap(generator, ctx, &inputs, ret_dtype, mapping)?;
Ok(ScalarOrNDArray::NDArray(ndarray))
}
}
pub fn map<'a, G, Mapping>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, 'a>,
ret_dtype: Type,
mapping: Mapping,
) -> Result<Self, String>
where
G: CodeGenerator + ?Sized,
Mapping: FnOnce(
&mut G,
&mut CodeGenContext<'ctx, 'a>,
Int<'ctx, SizeT>,
ScalarObject<'ctx>,
) -> Result<BasicValueEnum<'ctx>, String>,
{
ScalarOrNDArray::broadcasting_starmap(
generator,
ctx,
&[*self],
ret_dtype,
|generator, ctx, i, scalars| mapping(generator, ctx, i, scalars[0]),
)
}
}

View File

@ -229,6 +229,26 @@ impl<'ctx> NDArrayObject<'ctx> {
Self::alloca_uninitialized(generator, ctx, dtype, ndims, name)
}
/// Clone this ndaarray - Allocate a new ndarray with the same shape as this ndarray and copy the contents
/// over.
///
/// The new ndarray will own its data and will be C-contiguous.
pub fn make_clone<G: CodeGenerator + ?Sized>(
&self,
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
name: &str,
) -> Self {
let clone =
NDArrayObject::alloca_uninitialized(generator, ctx, self.dtype, self.ndims, name);
let shape = self.value.gep(ctx, |f| f.shape).load(generator, ctx, "shape");
clone.copy_shape_from_array(generator, ctx, shape);
clone.create_data(generator, ctx);
clone.copy_data_from(generator, ctx, *self);
clone
}
/// Get this ndarray's `ndims` as an LLVM constant.
pub fn get_ndims<G: CodeGenerator + ?Sized>(
&self,

View File

@ -13,22 +13,28 @@ use strum::IntoEnumIterator;
use crate::{
codegen::{
builtin_fns,
builtin_fns::{self},
classes::{ProxyValue, RangeValue},
expr::destructure_range,
irrt::*,
extern_fns,
irrt::{self, *},
llvm_intrinsics,
model::Int32,
numpy::*,
numpy_new::{self, gen_ndarray_transpose},
stmt::exn_constructor,
structure::ndarray::NDArrayObject,
structure::ndarray::{
functions::{FloorOrCeil, MinOrMax},
scalar::{split_scalar_or_ndarray, ScalarObject, ScalarOrNDArray},
NDArrayObject,
},
},
symbol_resolver::SymbolValue,
toplevel::{helper::PrimDef, numpy::make_ndarray_ty},
typecheck::{
numpy::create_ndims,
typedef::{into_var_map, iter_type_vars, TypeVar, VarMap},
toplevel::{
helper::PrimDef,
numpy::{create_ndims, make_ndarray_ty},
},
typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap},
};
use super::*;
@ -1053,16 +1059,34 @@ impl<'a> BuiltinBuilder<'a> {
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
let func = match prim {
PrimDef::FunInt32 => builtin_fns::call_int32,
PrimDef::FunInt64 => builtin_fns::call_int64,
PrimDef::FunUInt32 => builtin_fns::call_uint32,
PrimDef::FunUInt64 => builtin_fns::call_uint64,
PrimDef::FunFloat => builtin_fns::call_float,
PrimDef::FunBool => builtin_fns::call_bool,
let ret_dtype = match prim {
PrimDef::FunInt32 => ctx.primitives.int32,
PrimDef::FunInt64 => ctx.primitives.int64,
PrimDef::FunUInt32 => ctx.primitives.uint32,
PrimDef::FunUInt64 => ctx.primitives.uint64,
PrimDef::FunFloat => ctx.primitives.float,
PrimDef::FunBool => ctx.primitives.bool,
_ => unreachable!(),
};
Ok(Some(func(generator, ctx, (arg_ty, arg))?))
let result = split_scalar_or_ndarray(generator, ctx, arg, arg_ty).map(
generator,
ctx,
ret_dtype,
|generator, ctx, _i, scalar| {
let result = match prim {
PrimDef::FunInt32 => scalar.cast_to_int32(generator, ctx),
PrimDef::FunInt64 => scalar.cast_to_int64(generator, ctx),
PrimDef::FunUInt32 => scalar.cast_to_uint32(generator, ctx),
PrimDef::FunUInt64 => scalar.cast_to_uint64(generator, ctx),
PrimDef::FunFloat => scalar.cast_to_float(ctx),
PrimDef::FunBool => scalar.cast_to_bool(ctx),
_ => unreachable!(),
};
Ok(result.value)
},
)?;
Ok(Some(result.to_basic_value_enum()))
},
)))),
loc: None,
@ -1113,20 +1137,23 @@ impl<'a> BuiltinBuilder<'a> {
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
let ret_elem_ty = size_variant.of_int(&ctx.primitives);
Ok(Some(builtin_fns::call_round(generator, ctx, (arg_ty, arg), ret_elem_ty)?))
let ret_int_dtype = size_variant.of_int(&ctx.primitives);
let result = split_scalar_or_ndarray(generator, ctx, arg, arg_ty).map(
generator,
ctx,
ret_int_dtype,
|generator, ctx, _i, scalar| {
Ok(scalar.round(generator, ctx, ret_int_dtype).value)
},
)?;
Ok(Some(result.to_basic_value_enum()))
}),
)
}
/// Build the functions `ceil()` and `floor()` and their 64 bit variants.
fn build_ceil_floor_function(&mut self, prim: PrimDef) -> TopLevelDef {
#[derive(Clone, Copy)]
enum Kind {
Floor,
Ceil,
}
debug_assert_prim_is_allowed(
prim,
&[PrimDef::FunFloor, PrimDef::FunFloor64, PrimDef::FunCeil, PrimDef::FunCeil64],
@ -1134,10 +1161,10 @@ impl<'a> BuiltinBuilder<'a> {
let (size_variant, kind) = {
match prim {
PrimDef::FunFloor => (SizeVariant::Bits32, Kind::Floor),
PrimDef::FunFloor64 => (SizeVariant::Bits64, Kind::Floor),
PrimDef::FunCeil => (SizeVariant::Bits32, Kind::Ceil),
PrimDef::FunCeil64 => (SizeVariant::Bits64, Kind::Ceil),
PrimDef::FunFloor => (SizeVariant::Bits32, FloorOrCeil::Floor),
PrimDef::FunFloor64 => (SizeVariant::Bits64, FloorOrCeil::Floor),
PrimDef::FunCeil => (SizeVariant::Bits32, FloorOrCeil::Ceil),
PrimDef::FunCeil64 => (SizeVariant::Bits64, FloorOrCeil::Ceil),
_ => unreachable!(),
}
};
@ -1177,12 +1204,15 @@ impl<'a> BuiltinBuilder<'a> {
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
let ret_elem_ty = size_variant.of_int(&ctx.primitives);
let func = match kind {
Kind::Ceil => builtin_fns::call_ceil,
Kind::Floor => builtin_fns::call_floor,
};
Ok(Some(func(generator, ctx, (arg_ty, arg), ret_elem_ty)?))
let result = split_scalar_or_ndarray(generator, ctx, arg, arg_ty).map(
generator,
ctx,
int_sized,
|generator, ctx, _i, scalar| {
Ok(scalar.floor_or_ceil(generator, ctx, kind, int_sized).value)
},
)?;
Ok(Some(result.to_basic_value_enum()))
}),
)
}
@ -1546,12 +1576,22 @@ impl<'a> BuiltinBuilder<'a> {
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
let func = match prim {
PrimDef::FunNpCeil => builtin_fns::call_ceil,
PrimDef::FunNpFloor => builtin_fns::call_floor,
let kind = match prim {
PrimDef::FunNpFloor => FloorOrCeil::Floor,
PrimDef::FunNpCeil => FloorOrCeil::Ceil,
_ => unreachable!(),
};
Ok(Some(func(generator, ctx, (arg_ty, arg), ctx.primitives.float)?))
let result = split_scalar_or_ndarray(generator, ctx, arg, arg_ty).map(
generator,
ctx,
ctx.primitives.float,
move |_generator, ctx, _i, scalar| {
let result = scalar.np_floor_or_ceil(ctx, kind);
Ok(result.value)
},
)?;
Ok(Some(result.to_basic_value_enum()))
}),
)
}
@ -1569,7 +1609,17 @@ impl<'a> BuiltinBuilder<'a> {
Box::new(|ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
Ok(Some(builtin_fns::call_numpy_round(generator, ctx, (arg_ty, arg))?))
let result = split_scalar_or_ndarray(generator, ctx, arg, arg_ty).map(
generator,
ctx,
ctx.primitives.float,
|_generator, ctx, _i, scalar| {
let result = scalar.np_round(ctx);
Ok(result.value)
},
)?;
Ok(Some(result.to_basic_value_enum()))
}),
)
}
@ -1678,16 +1728,21 @@ impl<'a> BuiltinBuilder<'a> {
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
move |ctx, _, fun, args, generator| {
let m_ty = fun.0.args[0].ty;
let n_ty = fun.0.args[1].ty;
let m_val = args[0].1.clone().to_basic_value_enum(ctx, generator, m_ty)?;
let n_ty = fun.0.args[1].ty;
let n_val = args[1].1.clone().to_basic_value_enum(ctx, generator, n_ty)?;
let func = match prim {
PrimDef::FunMin => builtin_fns::call_min,
PrimDef::FunMax => builtin_fns::call_max,
let kind = match prim {
PrimDef::FunMin => MinOrMax::Min,
PrimDef::FunMax => MinOrMax::Max,
_ => unreachable!(),
};
Ok(Some(func(ctx, (m_ty, m_val), (n_ty, n_val))))
let m = ScalarObject { dtype: m_ty, value: m_val };
let n = ScalarObject { dtype: n_ty, value: n_val };
let result = ScalarObject::min_or_max(ctx, kind, m, n);
Ok(Some(result.value))
},
)))),
loc: None,
@ -1729,7 +1784,25 @@ impl<'a> BuiltinBuilder<'a> {
let a_ty = fun.0.args[0].ty;
let a = args[0].1.clone().to_basic_value_enum(ctx, generator, a_ty)?;
Ok(Some(builtin_fns::call_numpy_max_min(generator, ctx, (a_ty, a), prim.name())?))
let a = split_scalar_or_ndarray(generator, ctx, a, a_ty).as_ndarray(generator, ctx);
let result = match prim {
PrimDef::FunNpArgmin => a
.argmin_or_argmax(generator, ctx, MinOrMax::Min)
.value
.as_basic_value_enum(),
PrimDef::FunNpArgmax => a
.argmin_or_argmax(generator, ctx, MinOrMax::Max)
.value
.as_basic_value_enum(),
PrimDef::FunNpMin => {
a.min_or_max(generator, ctx, MinOrMax::Min).value.as_basic_value_enum()
}
PrimDef::FunNpMax => {
a.min_or_max(generator, ctx, MinOrMax::Max).value.as_basic_value_enum()
}
_ => unreachable!(),
};
Ok(Some(result))
}),
)
}
@ -1764,13 +1837,32 @@ impl<'a> BuiltinBuilder<'a> {
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
let func = match prim {
PrimDef::FunNpMinimum => builtin_fns::call_numpy_minimum,
PrimDef::FunNpMaximum => builtin_fns::call_numpy_maximum,
let kind = match prim {
PrimDef::FunNpMinimum => MinOrMax::Min,
PrimDef::FunNpMaximum => MinOrMax::Max,
_ => unreachable!(),
};
Ok(Some(func(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
let x1 = split_scalar_or_ndarray(generator, ctx, x1_val, x1_ty);
let x2 = split_scalar_or_ndarray(generator, ctx, x2_val, x2_ty);
// NOTE: x1.dtype() and x2.dtype() should be the same
let common_ty = x1.dtype();
let result = ScalarOrNDArray::broadcasting_starmap(
generator,
ctx,
&[x1, x2],
common_ty,
|_generator, ctx, _i, scalars| {
let x1 = scalars[0];
let x2 = scalars[1];
let result = ScalarObject::min_or_max(ctx, kind, x1, x2);
Ok(result.value)
},
)?;
Ok(Some(result.to_basic_value_enum()))
},
)))),
loc: None,
@ -1781,6 +1873,7 @@ impl<'a> BuiltinBuilder<'a> {
fn build_abs_function(&mut self) -> TopLevelDef {
let prim = PrimDef::FunAbs;
let num_ty = self.num_ty; // To move into codegen_callback
TopLevelDef::Function {
name: prim.name().into(),
simple_name: prim.simple_name().into(),
@ -1798,11 +1891,17 @@ impl<'a> BuiltinBuilder<'a> {
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, _, fun, args, generator| {
move |ctx, _, fun, args, generator| {
let n_ty = fun.0.args[0].ty;
let n_val = args[0].1.clone().to_basic_value_enum(ctx, generator, n_ty)?;
Ok(Some(builtin_fns::call_abs(generator, ctx, (n_ty, n_val))?))
let result = split_scalar_or_ndarray(generator, ctx, n_val, n_ty).map(
generator,
ctx,
num_ty.ty,
|_generator, ctx, _i, scalar| Ok(scalar.abs(ctx).value),
)?;
Ok(Some(result.to_basic_value_enum()))
},
)))),
loc: None,
@ -1825,13 +1924,23 @@ impl<'a> BuiltinBuilder<'a> {
let x_ty = fun.0.args[0].ty;
let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?;
let func = match prim {
PrimDef::FunNpIsInf => builtin_fns::call_numpy_isinf,
PrimDef::FunNpIsNan => builtin_fns::call_numpy_isnan,
let function = match prim {
PrimDef::FunNpIsInf => irrt::call_isnan,
PrimDef::FunNpIsNan => irrt::call_isinf,
_ => unreachable!(),
};
Ok(Some(func(generator, ctx, (x_ty, x_val))?))
let result = split_scalar_or_ndarray(generator, ctx, x_val, x_ty).map(
generator,
ctx,
ctx.primitives.bool,
|generator, ctx, _i, scalar| {
let n = scalar.into_float64(ctx);
let n = function(generator, ctx, n);
Ok(n.as_basic_value_enum())
},
)?;
Ok(Some(result.to_basic_value_enum()))
}),
)
}
@ -1889,49 +1998,58 @@ impl<'a> BuiltinBuilder<'a> {
let arg_ty = fun.0.args[0].ty;
let arg_val = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
let func = match prim {
PrimDef::FunNpSin => builtin_fns::call_numpy_sin,
PrimDef::FunNpCos => builtin_fns::call_numpy_cos,
PrimDef::FunNpTan => builtin_fns::call_numpy_tan,
let result = split_scalar_or_ndarray(generator, ctx, arg_val, arg_ty).map(
generator,
ctx,
ctx.primitives.float,
|_generator, ctx, _i, scalar| {
let n = scalar.into_float64(ctx);
let n = match prim {
PrimDef::FunNpSin => llvm_intrinsics::call_float_sin(ctx, n, None),
PrimDef::FunNpCos => llvm_intrinsics::call_float_cos(ctx, n, None),
PrimDef::FunNpTan => extern_fns::call_tan(ctx, n, None),
PrimDef::FunNpArcsin => builtin_fns::call_numpy_arcsin,
PrimDef::FunNpArccos => builtin_fns::call_numpy_arccos,
PrimDef::FunNpArctan => builtin_fns::call_numpy_arctan,
PrimDef::FunNpArcsin => extern_fns::call_asin(ctx, n, None),
PrimDef::FunNpArccos => extern_fns::call_acos(ctx, n, None),
PrimDef::FunNpArctan => extern_fns::call_atan(ctx, n, None),
PrimDef::FunNpSinh => builtin_fns::call_numpy_sinh,
PrimDef::FunNpCosh => builtin_fns::call_numpy_cosh,
PrimDef::FunNpTanh => builtin_fns::call_numpy_tanh,
PrimDef::FunNpSinh => extern_fns::call_sinh(ctx, n, None),
PrimDef::FunNpCosh => extern_fns::call_cosh(ctx, n, None),
PrimDef::FunNpTanh => extern_fns::call_tanh(ctx, n, None),
PrimDef::FunNpArcsinh => builtin_fns::call_numpy_arcsinh,
PrimDef::FunNpArccosh => builtin_fns::call_numpy_arccosh,
PrimDef::FunNpArctanh => builtin_fns::call_numpy_arctanh,
PrimDef::FunNpArcsinh => extern_fns::call_asinh(ctx, n, None),
PrimDef::FunNpArccosh => extern_fns::call_acosh(ctx, n, None),
PrimDef::FunNpArctanh => extern_fns::call_atanh(ctx, n, None),
PrimDef::FunNpExp => builtin_fns::call_numpy_exp,
PrimDef::FunNpExp2 => builtin_fns::call_numpy_exp2,
PrimDef::FunNpExpm1 => builtin_fns::call_numpy_expm1,
PrimDef::FunNpExp => llvm_intrinsics::call_float_exp(ctx, n, None),
PrimDef::FunNpExp2 => llvm_intrinsics::call_float_exp2(ctx, n, None),
PrimDef::FunNpExpm1 => extern_fns::call_expm1(ctx, n, None),
PrimDef::FunNpLog => builtin_fns::call_numpy_log,
PrimDef::FunNpLog2 => builtin_fns::call_numpy_log2,
PrimDef::FunNpLog10 => builtin_fns::call_numpy_log10,
PrimDef::FunNpLog => llvm_intrinsics::call_float_log(ctx, n, None),
PrimDef::FunNpLog2 => llvm_intrinsics::call_float_log2(ctx, n, None),
PrimDef::FunNpLog10 => llvm_intrinsics::call_float_log10(ctx, n, None),
PrimDef::FunNpSqrt => builtin_fns::call_numpy_sqrt,
PrimDef::FunNpCbrt => builtin_fns::call_numpy_cbrt,
PrimDef::FunNpSqrt => llvm_intrinsics::call_float_sqrt(ctx, n, None),
PrimDef::FunNpCbrt => extern_fns::call_cbrt(ctx, n, None),
PrimDef::FunNpFabs => builtin_fns::call_numpy_fabs,
PrimDef::FunNpRint => builtin_fns::call_numpy_rint,
PrimDef::FunNpFabs => llvm_intrinsics::call_float_fabs(ctx, n, None),
PrimDef::FunNpRint => llvm_intrinsics::call_float_rint(ctx, n, None),
PrimDef::FunSpSpecErf => builtin_fns::call_scipy_special_erf,
PrimDef::FunSpSpecErfc => builtin_fns::call_scipy_special_erfc,
PrimDef::FunSpSpecErf => extern_fns::call_erf(ctx, n, None),
PrimDef::FunSpSpecErfc => extern_fns::call_erfc(ctx, n, None),
PrimDef::FunSpSpecGamma => builtin_fns::call_scipy_special_gamma,
PrimDef::FunSpSpecGammaln => builtin_fns::call_scipy_special_gammaln,
PrimDef::FunSpSpecGamma => irrt::call_gamma(ctx, n),
PrimDef::FunSpSpecGammaln => irrt::call_gammaln(ctx, n),
PrimDef::FunSpSpecJ0 => builtin_fns::call_scipy_special_j0,
PrimDef::FunSpSpecJ1 => builtin_fns::call_scipy_special_j1,
PrimDef::FunSpSpecJ0 => irrt::call_j0(ctx, n),
PrimDef::FunSpSpecJ1 => extern_fns::call_j1(ctx, n, None),
_ => unreachable!(),
};
Ok(Some(func(generator, ctx, (arg_ty, arg_val))?))
Ok(n.as_basic_value_enum())
},
)?;
Ok(Some(result.to_basic_value_enum()))
}),
)
}
@ -1953,20 +2071,20 @@ impl<'a> BuiltinBuilder<'a> {
let PrimitiveStore { float, int32, .. } = *self.primitives;
// The argument types of the two input arguments are controlled here.
let (x1_ty, x2_ty) = match prim {
// The argument types of the two input arguments + the return type
let (x1_dtype, x2_dtype, ret_dtype) = match prim {
PrimDef::FunNpArctan2
| PrimDef::FunNpCopysign
| PrimDef::FunNpFmax
| PrimDef::FunNpFmin
| PrimDef::FunNpHypot
| PrimDef::FunNpNextAfter => (float, float),
PrimDef::FunNpLdExp => (float, int32),
| PrimDef::FunNpNextAfter => (float, float, float),
PrimDef::FunNpLdExp => (float, int32, float),
_ => unreachable!(),
};
let x1_ty = self.new_type_or_ndarray_ty(x1_ty);
let x2_ty = self.new_type_or_ndarray_ty(x2_ty);
let x1_ty = self.new_type_or_ndarray_ty(x1_dtype);
let x2_ty = self.new_type_or_ndarray_ty(x2_dtype);
let param_ty = &[(x1_ty.ty, "x1"), (x2_ty.ty, "x2")];
let ret_ty = self.unifier.get_fresh_var(None, None);
@ -1990,21 +2108,46 @@ impl<'a> BuiltinBuilder<'a> {
move |ctx, _, fun, args, generator| {
let x1_ty = fun.0.args[0].ty;
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
let x2_ty = fun.0.args[1].ty;
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
let func = match prim {
PrimDef::FunNpArctan2 => builtin_fns::call_numpy_arctan2,
PrimDef::FunNpCopysign => builtin_fns::call_numpy_copysign,
PrimDef::FunNpFmax => builtin_fns::call_numpy_fmax,
PrimDef::FunNpFmin => builtin_fns::call_numpy_fmin,
PrimDef::FunNpLdExp => builtin_fns::call_numpy_ldexp,
PrimDef::FunNpHypot => builtin_fns::call_numpy_hypot,
PrimDef::FunNpNextAfter => builtin_fns::call_numpy_nextafter,
let x1 = split_scalar_or_ndarray(generator, ctx, x1_val, x1_ty);
let x2 = split_scalar_or_ndarray(generator, ctx, x2_val, x2_ty);
let result = ScalarOrNDArray::broadcasting_starmap(
generator,
ctx,
&[x1, x2],
ret_dtype,
|_generator, ctx, _i, scalars| {
let x1 = scalars[0];
let x2 = scalars[1];
let result = match prim {
PrimDef::FunNpArctan2
| PrimDef::FunNpCopysign
| PrimDef::FunNpFmax
| PrimDef::FunNpFmin
| PrimDef::FunNpHypot
| PrimDef::FunNpNextAfter => {
let x1 = x1.into_float64(ctx);
let x2 = x2.into_float64(ctx);
extern_fns::call_atan2(ctx, x1, x2, None).as_basic_value_enum()
}
PrimDef::FunNpLdExp => {
let x1 = x1.into_float64(ctx);
let x2 = x2.into_int32(ctx);
extern_fns::call_ldexp(ctx, x1, x2, None).as_basic_value_enum()
}
_ => unreachable!(),
};
Ok(Some(func(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
Ok(result)
},
)?;
Ok(Some(result.to_basic_value_enum()))
},
)))),
loc: None,

View File

@ -1,4 +1,7 @@
use std::sync::Arc;
use crate::{
symbol_resolver::SymbolValue,
toplevel::helper::PrimDef,
typecheck::{
type_inferencer::PrimitiveStore,
@ -83,3 +86,33 @@ pub fn unpack_ndarray_var_ids(unifier: &mut Unifier, ndarray: Type) -> (TypeVarI
pub fn unpack_ndarray_var_tys(unifier: &mut Unifier, ndarray: Type) -> (Type, Type) {
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.1).collect_tuple().unwrap()
}
/// Extract an ndarray's `ndims` [type][`Type`] in `u64`. Panic if not possible.
/// The `ndims` must only contain 1 value.
#[must_use]
pub fn extract_ndims(unifier: &Unifier, ndims_ty: Type) -> u64 {
let ndims_ty_enum = unifier.get_ty_immutable(ndims_ty);
let TypeEnum::TLiteral { values, .. } = &*ndims_ty_enum else {
panic!("ndims_ty should be a TLiteral");
};
assert_eq!(values.len(), 1, "ndims_ty TLiteral should only contain 1 value");
let ndims = values[0].clone();
u64::try_from(ndims).unwrap()
}
/// Return an ndarray's `ndims` as a typechecker [`Type`] from its `u64` value.
pub fn create_ndims(unifier: &mut Unifier, ndims: u64) -> Type {
unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None)
}
/// Return the ndims after broadcasting ndarrays of different ndims.
///
/// Panics if the input list is empty.
pub fn get_broadcast_all_ndims<I>(ndims: I) -> u64
where
I: IntoIterator<Item = u64>,
{
ndims.into_iter().max().unwrap()
}

View File

@ -1,6 +1,5 @@
mod function_check;
pub mod magic_methods;
pub mod numpy;
pub mod type_error;
pub mod type_inferencer;
pub mod typedef;

View File

@ -1,33 +0,0 @@
use crate::{symbol_resolver::SymbolValue, typecheck::typedef::TypeEnum};
use super::typedef::{Type, Unifier};
/// Extract an ndarray's `ndims` [type][`Type`] in `u64`. Panic if not possible.
/// The `ndims` must only contain 1 value.
#[must_use]
pub fn extract_ndims(unifier: &Unifier, ndims_ty: Type) -> u64 {
let ndims_ty_enum = unifier.get_ty_immutable(ndims_ty);
let TypeEnum::TLiteral { values, .. } = &*ndims_ty_enum else {
panic!("ndims_ty should be a TLiteral");
};
assert_eq!(values.len(), 1, "ndims_ty TLiteral should only contain 1 value");
let ndims = values[0].clone();
u64::try_from(ndims).unwrap()
}
/// Return an ndarray's `ndims` as a typechecker [`Type`] from its `u64` value.
pub fn create_ndims(unifier: &mut Unifier, ndims: u64) -> Type {
unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None)
}
/// Return the ndims after broadcasting ndarrays of different ndims.
///
/// Panics if the input list is empty.
pub fn get_broadcast_all_ndims<I>(ndims: I) -> u64
where
I: IntoIterator<Item = u64>,
{
ndims.into_iter().max().unwrap()
}

View File

@ -11,12 +11,11 @@ use super::{
RecordField, RecordKey, Type, TypeEnum, TypeVar, Unifier, VarMap,
},
};
use crate::typecheck::numpy::extract_ndims;
use crate::{
symbol_resolver::{SymbolResolver, SymbolValue},
toplevel::{
helper::{arraylike_flatten_element_type, arraylike_get_ndims, PrimDef},
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
numpy::{extract_ndims, make_ndarray_ty, unpack_ndarray_var_tys},
TopLevelContext, TopLevelDef,
},
typecheck::typedef::Mapping,