forked from M-Labs/nac3
WIP: core/ndstrides: AnyObject + TupleObject
This commit is contained in:
parent
febe78b6a4
commit
2fbe981701
|
@ -1,15 +1,14 @@
|
|||
use inkwell::types::BasicTypeEnum;
|
||||
use inkwell::values::{BasicValue, BasicValueEnum, IntValue, PointerValue};
|
||||
use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
|
||||
use inkwell::IntPredicate;
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::codegen::classes::{
|
||||
ArrayLikeValue, NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor,
|
||||
UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
|
||||
UntypedArrayLikeAccessor,
|
||||
};
|
||||
use crate::codegen::expr::destructure_range;
|
||||
use crate::codegen::irrt::calculate_len_for_slice_range;
|
||||
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
|
||||
use crate::codegen::{extern_fns, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
|
||||
use crate::toplevel::helper::PrimDef;
|
||||
use crate::toplevel::numpy::unpack_ndarray_var_tys;
|
||||
|
|
|
@ -12,7 +12,10 @@ use crate::{
|
|||
call_memcpy_generic,
|
||||
},
|
||||
need_sret, numpy,
|
||||
object::ndarray::{scalar::split_scalar_or_ndarray, NDArrayObject, NDArrayOut},
|
||||
object::{
|
||||
ndarray::{scalar::split_scalar_or_ndarray, NDArrayObject, NDArrayOut},
|
||||
AnyObject,
|
||||
},
|
||||
stmt::{
|
||||
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
|
||||
gen_var,
|
||||
|
@ -1537,10 +1540,11 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||
} else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||
|| ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||
{
|
||||
let left =
|
||||
split_scalar_or_ndarray(generator, ctx, left_val, ty1).as_ndarray(generator, ctx);
|
||||
let right =
|
||||
split_scalar_or_ndarray(generator, ctx, right_val, ty2).as_ndarray(generator, ctx);
|
||||
let left = AnyObject { ty: ty1, value: left_val };
|
||||
let right = AnyObject { ty: ty1, value: right_val };
|
||||
|
||||
let left = split_scalar_or_ndarray(generator, ctx, left).as_ndarray(generator, ctx);
|
||||
let right = split_scalar_or_ndarray(generator, ctx, right).as_ndarray(generator, ctx);
|
||||
|
||||
debug_assert!(ctx.unifier.unioned(left.dtype, right.dtype)); // Typechecker ensures this.
|
||||
|
||||
|
@ -2860,8 +2864,9 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
|||
|
||||
let ndarray_ty = value.custom.unwrap();
|
||||
let ndarray = ndarray.to_basic_value_enum(ctx, generator, ndarray_ty)?;
|
||||
let ndarray =
|
||||
NDArrayObject::from_value_and_type(generator, ctx, ndarray, ndarray_ty);
|
||||
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
|
||||
|
||||
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
|
||||
|
||||
let indexes = gen_ndarray_subscript_ndindexes(generator, ctx, slice)?;
|
||||
let result = ndarray
|
||||
|
|
|
@ -13,7 +13,10 @@ use crate::{
|
|||
use super::{
|
||||
irrt::call_nac3_ndarray_util_assert_shape_no_negative,
|
||||
model::*,
|
||||
object::ndarray::{shape_util::parse_numpy_int_sequence, NDArrayObject},
|
||||
object::{
|
||||
ndarray::{shape_util::parse_numpy_int_sequence, NDArrayObject},
|
||||
AnyObject,
|
||||
},
|
||||
CodeGenContext, CodeGenerator,
|
||||
};
|
||||
|
||||
|
@ -84,13 +87,12 @@ fn create_empty_ndarray<'ctx, G>(
|
|||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ndarray_ty: Type,
|
||||
shape: BasicValueEnum<'ctx>,
|
||||
shape_ty: Type,
|
||||
shape: AnyObject<'ctx>,
|
||||
) -> NDArrayObject<'ctx>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
{
|
||||
let (_, shape) = parse_numpy_int_sequence(generator, ctx, shape, shape_ty);
|
||||
let (_, shape) = parse_numpy_int_sequence(generator, ctx, shape);
|
||||
|
||||
let ndarray =
|
||||
NDArrayObject::alloca_uninitialized_of_type(generator, ctx, ndarray_ty, "ndarray");
|
||||
|
@ -120,10 +122,11 @@ pub fn gen_ndarray_empty<'ctx>(
|
|||
// Parse arguments
|
||||
let shape_ty = fun.0.args[0].ty;
|
||||
let shape = args[0].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
|
||||
let shape = AnyObject { ty: shape_ty, value: shape };
|
||||
|
||||
// Implementation
|
||||
let ndarray_ty = fun.0.ret;
|
||||
let ndarray = create_empty_ndarray(generator, ctx, ndarray_ty, shape, shape_ty);
|
||||
let ndarray = create_empty_ndarray(generator, ctx, ndarray_ty, shape);
|
||||
|
||||
Ok(ndarray.instance.value.as_basic_value_enum())
|
||||
}
|
||||
|
@ -142,10 +145,11 @@ pub fn gen_ndarray_zeros<'ctx>(
|
|||
// Parse arguments
|
||||
let shape_ty = fun.0.args[0].ty;
|
||||
let shape = args[0].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
|
||||
let shape = AnyObject { ty: shape_ty, value: shape };
|
||||
|
||||
// Implementation
|
||||
let ndarray_ty = fun.0.ret;
|
||||
let ndarray = create_empty_ndarray(generator, ctx, ndarray_ty, shape, shape_ty);
|
||||
let ndarray = create_empty_ndarray(generator, ctx, ndarray_ty, shape);
|
||||
|
||||
let fill_value = ndarray_zero_value(generator, ctx, ndarray.dtype);
|
||||
ndarray.fill(generator, ctx, fill_value);
|
||||
|
@ -167,10 +171,11 @@ pub fn gen_ndarray_ones<'ctx>(
|
|||
// Parse arguments
|
||||
let shape_ty = fun.0.args[0].ty;
|
||||
let shape = args[0].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
|
||||
let shape = AnyObject { ty: shape_ty, value: shape };
|
||||
|
||||
// Implementation
|
||||
let ndarray_ty = fun.0.ret;
|
||||
let ndarray = create_empty_ndarray(generator, ctx, ndarray_ty, shape, shape_ty);
|
||||
let ndarray = create_empty_ndarray(generator, ctx, ndarray_ty, shape);
|
||||
|
||||
let fill_value = ndarray_one_value(generator, ctx, ndarray.dtype);
|
||||
ndarray.fill(generator, ctx, fill_value);
|
||||
|
@ -192,6 +197,7 @@ pub fn gen_ndarray_full<'ctx>(
|
|||
// Parse argument #1 shape
|
||||
let shape_ty = fun.0.args[0].ty;
|
||||
let shape = args[0].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
|
||||
let shape = AnyObject { ty: shape_ty, value: shape };
|
||||
|
||||
// Parse argument #2 fill_value
|
||||
let fill_value_ty = fun.0.args[1].ty;
|
||||
|
@ -199,7 +205,7 @@ pub fn gen_ndarray_full<'ctx>(
|
|||
|
||||
// Implementation
|
||||
let ndarray_ty = fun.0.ret;
|
||||
let ndarray = create_empty_ndarray(generator, ctx, ndarray_ty, shape, shape_ty);
|
||||
let ndarray = create_empty_ndarray(generator, ctx, ndarray_ty, shape);
|
||||
|
||||
ndarray.fill(generator, ctx, fill_value);
|
||||
|
||||
|
@ -220,10 +226,12 @@ pub fn gen_ndarray_broadcast_to<'ctx>(
|
|||
// Parse argument #1 input
|
||||
let input_ty = fun.0.args[0].ty;
|
||||
let input = args[0].1.clone().to_basic_value_enum(ctx, generator, input_ty)?;
|
||||
let input = AnyObject { ty: input_ty, value: input };
|
||||
|
||||
// Parse argument #2 shape
|
||||
let shape_ty = fun.0.args[1].ty;
|
||||
let shape = args[1].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
|
||||
let shape = AnyObject { ty: shape_ty, value: shape };
|
||||
|
||||
// Define models
|
||||
let sizet_model = IntModel(SizeT);
|
||||
|
@ -234,11 +242,10 @@ pub fn gen_ndarray_broadcast_to<'ctx>(
|
|||
let broadcast_ndims = extract_ndims(&ctx.unifier, broadcast_ndims_ty);
|
||||
|
||||
// Process `input`
|
||||
let in_ndarray =
|
||||
split_scalar_or_ndarray(generator, ctx, input, input_ty).as_ndarray(generator, ctx);
|
||||
let in_ndarray = split_scalar_or_ndarray(generator, ctx, input).as_ndarray(generator, ctx);
|
||||
|
||||
// Process `shape`
|
||||
let (_, broadcast_shape) = parse_numpy_int_sequence(generator, ctx, shape, shape_ty);
|
||||
let (_, broadcast_shape) = parse_numpy_int_sequence(generator, ctx, shape);
|
||||
// NOTE: shape.size should equal to `broadcasted_ndims`.
|
||||
let broadcast_ndims_llvm = sizet_model.constant(generator, ctx.ctx, broadcast_ndims);
|
||||
call_nac3_ndarray_util_assert_shape_no_negative(
|
||||
|
@ -269,23 +276,24 @@ pub fn gen_ndarray_reshape<'ctx>(
|
|||
// Parse argument #1 input
|
||||
let input_ty = fun.0.args[0].ty;
|
||||
let input = args[0].1.clone().to_basic_value_enum(ctx, generator, input_ty)?;
|
||||
let input = AnyObject { ty: input_ty, value: input };
|
||||
|
||||
// Parse argument #2 shape
|
||||
let shape_ty = fun.0.args[1].ty;
|
||||
let shape = args[1].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
|
||||
let shape = AnyObject { ty: shape_ty, value: shape };
|
||||
|
||||
// Extract reshaped_ndims
|
||||
let (_, reshaped_ndims_ty) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
|
||||
let reshaped_ndims = extract_ndims(&ctx.unifier, reshaped_ndims_ty);
|
||||
|
||||
// Process `input`
|
||||
let in_ndarray =
|
||||
split_scalar_or_ndarray(generator, ctx, input, input_ty).as_ndarray(generator, ctx);
|
||||
let in_ndarray = split_scalar_or_ndarray(generator, ctx, input).as_ndarray(generator, ctx);
|
||||
|
||||
// Process the shape input from user and resolve negative indices.
|
||||
// The resulting `new_shape`'s size should be equal to reshaped_ndims.
|
||||
// This is ensured by the typechecker.
|
||||
let (_, new_shape) = parse_numpy_int_sequence(generator, ctx, shape, shape_ty);
|
||||
let (_, new_shape) = parse_numpy_int_sequence(generator, ctx, shape);
|
||||
let reshaped_ndarray = in_ndarray.reshape_or_copy(generator, ctx, reshaped_ndims, new_shape);
|
||||
|
||||
Ok(reshaped_ndarray.instance.value.as_basic_value_enum())
|
||||
|
@ -354,8 +362,8 @@ pub fn gen_ndarray_size<'ctx>(
|
|||
|
||||
let ndarray_ty = fun.0.args[0].ty;
|
||||
let ndarray = args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
|
||||
|
||||
let ndarray = NDArrayObject::from_value_and_type(generator, ctx, ndarray, ndarray_ty);
|
||||
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
|
||||
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
|
||||
|
||||
let size = ndarray.size(generator, ctx).truncate(generator, ctx, Int32, "size");
|
||||
Ok(size.value.as_basic_value_enum())
|
||||
|
@ -375,14 +383,15 @@ pub fn gen_ndarray_shape<'ctx>(
|
|||
// Parse argument #1 ndarray
|
||||
let ndarray_ty = fun.0.args[0].ty;
|
||||
let ndarray = args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
|
||||
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
|
||||
|
||||
// Define models
|
||||
let sizet_model = IntModel(SizeT);
|
||||
|
||||
// Process ndarray
|
||||
let ndarray = NDArrayObject::from_value_and_type(generator, ctx, ndarray, ndarray_ty);
|
||||
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
|
||||
|
||||
let mut items = Vec::with_capacity(ndarray.ndims as usize);
|
||||
let mut objects = Vec::with_capacity(ndarray.ndims as usize);
|
||||
|
||||
for i in 0..ndarray.ndims {
|
||||
let i = sizet_model.constant(generator, ctx.ctx, i);
|
||||
|
@ -392,10 +401,11 @@ pub fn gen_ndarray_shape<'ctx>(
|
|||
.ix(generator, ctx, i.value, "dim");
|
||||
let dim = dim.truncate(generator, ctx, Int32, "dim"); // TODO: keep using SizeT
|
||||
|
||||
items.push((dim.value.as_basic_value_enum(), ctx.primitives.int32));
|
||||
objects
|
||||
.push(AnyObject { ty: ctx.primitives.int32, value: dim.value.as_basic_value_enum() });
|
||||
}
|
||||
|
||||
let shape = TupleObject::create(generator, ctx, items, "shape");
|
||||
let shape = TupleObject::create(generator, ctx, objects, "shape");
|
||||
Ok(shape.value.as_basic_value_enum())
|
||||
}
|
||||
|
||||
|
@ -407,7 +417,7 @@ pub fn gen_ndarray_strides<'ctx>(
|
|||
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
||||
generator: &mut dyn CodeGenerator,
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
// TODO: This function looks exactly like `gen_ndarray_shapes`, code duplication?
|
||||
// TODO: Code duplication: This function looks exactly like `gen_ndarray_shapes`.
|
||||
|
||||
assert!(obj.is_none());
|
||||
assert_eq!(args.len(), 1);
|
||||
|
@ -415,14 +425,15 @@ pub fn gen_ndarray_strides<'ctx>(
|
|||
// Parse argument #1 ndarray
|
||||
let ndarray_ty = fun.0.args[0].ty;
|
||||
let ndarray = args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
|
||||
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
|
||||
|
||||
// Define models
|
||||
let sizet_model = IntModel(SizeT);
|
||||
|
||||
// Process ndarray
|
||||
let ndarray = NDArrayObject::from_value_and_type(generator, ctx, ndarray, ndarray_ty);
|
||||
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
|
||||
|
||||
let mut items = Vec::with_capacity(ndarray.ndims as usize);
|
||||
let mut objects = Vec::with_capacity(ndarray.ndims as usize);
|
||||
|
||||
for i in 0..ndarray.ndims {
|
||||
let i = sizet_model.constant(generator, ctx.ctx, i);
|
||||
|
@ -432,10 +443,11 @@ pub fn gen_ndarray_strides<'ctx>(
|
|||
.ix(generator, ctx, i.value, "dim");
|
||||
let dim = dim.truncate(generator, ctx, Int32, "dim"); // TODO: keep using SizeT
|
||||
|
||||
items.push((dim.value.as_basic_value_enum(), ctx.primitives.int32));
|
||||
objects
|
||||
.push(AnyObject { ty: ctx.primitives.int32, value: dim.value.as_basic_value_enum() });
|
||||
}
|
||||
|
||||
let strides = TupleObject::create(generator, ctx, items, "strides");
|
||||
let strides = TupleObject::create(generator, ctx, objects, "strides");
|
||||
Ok(strides.value.as_basic_value_enum())
|
||||
}
|
||||
|
||||
|
@ -458,20 +470,23 @@ pub fn gen_ndarray_transpose<'ctx>(
|
|||
// Parse argument #1 ndarray
|
||||
let ndarray_ty = fun.0.args[0].ty;
|
||||
let ndarray = args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
|
||||
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
|
||||
|
||||
// Implementation
|
||||
let ndarray = NDArrayObject::from_value_and_type(generator, ctx, ndarray, ndarray_ty);
|
||||
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
|
||||
|
||||
let has_axes = args.len() >= 2;
|
||||
let transposed_ndarray = if has_axes {
|
||||
// Parse argument #2 axes
|
||||
let in_axes_ty = fun.0.args[1].ty;
|
||||
let in_axes = args[1].1.clone().to_basic_value_enum(ctx, generator, in_axes_ty)?;
|
||||
let in_axes = AnyObject { ty: in_axes_ty, value: in_axes };
|
||||
|
||||
let (_, axes) = parse_numpy_int_sequence(generator, ctx, in_axes, in_axes_ty);
|
||||
let (_, axes) = parse_numpy_int_sequence(generator, ctx, in_axes);
|
||||
|
||||
ndarray.transpose(generator, ctx, Some(axes))
|
||||
} else {
|
||||
// axes is not given
|
||||
ndarray.transpose(generator, ctx, None)
|
||||
};
|
||||
|
||||
|
@ -488,8 +503,9 @@ pub fn gen_ndarray_array<'ctx>(
|
|||
assert!(obj.is_none());
|
||||
assert!(matches!(args.len(), 1..=3));
|
||||
|
||||
let obj_ty = fun.0.args[0].ty;
|
||||
let obj_arg = args[0].1.clone().to_basic_value_enum(ctx, generator, obj_ty)?;
|
||||
let object_ty = fun.0.args[0].ty;
|
||||
let object = args[0].1.clone().to_basic_value_enum(ctx, generator, object_ty)?;
|
||||
let object = AnyObject { ty: object_ty, value: object };
|
||||
|
||||
let copy_arg = if let Some(arg) =
|
||||
args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[1].name))
|
||||
|
@ -513,7 +529,7 @@ pub fn gen_ndarray_array<'ctx>(
|
|||
let copy = IntModel(Byte).check_value(generator, ctx.ctx, copy_arg).unwrap(); // NAC3 booleans are i8
|
||||
let copy = copy.truncate(generator, ctx, Bool, "copy_bool");
|
||||
|
||||
let ndarray = NDArrayObject::from_np_array(generator, ctx, obj_arg, obj_ty, copy);
|
||||
let ndarray = NDArrayObject::from_np_array(generator, ctx, object, copy);
|
||||
debug_assert!(ndarray.ndims <= output_ndims); // Sanity check on `ndims`
|
||||
|
||||
let ndarray = ndarray.atleast_nd(generator, ctx, output_ndims);
|
||||
|
|
|
@ -5,6 +5,8 @@ use crate::{
|
|||
typecheck::typedef::{iter_type_vars, Type, TypeEnum},
|
||||
};
|
||||
|
||||
use super::AnyObject;
|
||||
|
||||
/// A NAC3 Python List object.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct ListObject<'ctx> {
|
||||
|
@ -15,21 +17,20 @@ pub struct ListObject<'ctx> {
|
|||
|
||||
impl<'ctx> ListObject<'ctx> {
|
||||
/// Create a [`ListObject`] from an LLVM value and its typechecker [`Type`].
|
||||
pub fn from_value_and_type<V: BasicValue<'ctx>, G: CodeGenerator + ?Sized>(
|
||||
pub fn from_object<G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
list_val: V,
|
||||
list_type: Type,
|
||||
object: AnyObject<'ctx>,
|
||||
) -> Self {
|
||||
// Check typechecker type and extract `item_type`
|
||||
let item_type = match &*ctx.unifier.get_ty(list_type) {
|
||||
let item_type = match &*ctx.unifier.get_ty(object.ty) {
|
||||
TypeEnum::TObj { obj_id, params, .. }
|
||||
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
|
||||
{
|
||||
iter_type_vars(params).next().unwrap().ty // Extract `item_type`
|
||||
}
|
||||
_ => {
|
||||
panic!("Expecting type to be a list, but got {}", ctx.unifier.stringify(list_type))
|
||||
panic!("Expecting type to be a list, but got {}", ctx.unifier.stringify(object.ty))
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -37,7 +38,7 @@ impl<'ctx> ListObject<'ctx> {
|
|||
let plist_model = PtrModel(StructModel(List { item: item_model }));
|
||||
|
||||
// Create object
|
||||
let value = plist_model.check_value(generator, ctx.ctx, list_val).unwrap();
|
||||
let value = plist_model.check_value(generator, ctx.ctx, object.value).unwrap();
|
||||
ListObject { item_type, instance: value }
|
||||
}
|
||||
|
||||
|
|
|
@ -1,3 +1,13 @@
|
|||
use inkwell::values::BasicValueEnum;
|
||||
|
||||
use crate::typecheck::typedef::Type;
|
||||
|
||||
pub mod list;
|
||||
pub mod ndarray;
|
||||
pub mod tuple;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct AnyObject<'ctx> {
|
||||
pub ty: Type,
|
||||
pub value: BasicValueEnum<'ctx>,
|
||||
}
|
||||
|
|
|
@ -1,11 +1,9 @@
|
|||
use inkwell::values::BasicValueEnum;
|
||||
|
||||
use super::NDArrayObject;
|
||||
use crate::{
|
||||
codegen::{
|
||||
irrt::{call_nac3_array_set_and_validate_list_shape, call_nac3_array_write_list_to_array},
|
||||
model::*,
|
||||
object::list::ListObject,
|
||||
object::{list::ListObject, AnyObject},
|
||||
stmt::gen_if_else_expr_callback,
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
|
@ -26,7 +24,7 @@ fn get_list_object_dtype_and_ndims<'ctx>(
|
|||
}
|
||||
|
||||
impl<'ctx> NDArrayObject<'ctx> {
|
||||
fn from_np_array_list_copy<G: CodeGenerator + ?Sized>(
|
||||
fn from_np_array_list_copy_impl<G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
list: ListObject<'ctx>,
|
||||
|
@ -59,7 +57,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
|||
ndarray
|
||||
}
|
||||
|
||||
fn from_np_array_list_try_no_copy<G: CodeGenerator + ?Sized>(
|
||||
fn from_np_array_list_try_no_copy_impl<G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
list: ListObject<'ctx>,
|
||||
|
@ -101,11 +99,11 @@ impl<'ctx> NDArrayObject<'ctx> {
|
|||
ndarray
|
||||
} else {
|
||||
// `list` is nested, it is impossible to not copy.
|
||||
NDArrayObject::from_np_array_list_copy(generator, ctx, list)
|
||||
NDArrayObject::from_np_array_list_copy_impl(generator, ctx, list)
|
||||
}
|
||||
}
|
||||
|
||||
fn from_np_array_list<G: CodeGenerator + ?Sized>(
|
||||
fn from_np_array_list_impl<G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
list: ListObject<'ctx>,
|
||||
|
@ -118,11 +116,12 @@ impl<'ctx> NDArrayObject<'ctx> {
|
|||
ctx,
|
||||
|_generator, _ctx| Ok(copy.value),
|
||||
|generator, ctx| {
|
||||
let ndarray = NDArrayObject::from_np_array_list_copy(generator, ctx, list);
|
||||
let ndarray = NDArrayObject::from_np_array_list_copy_impl(generator, ctx, list);
|
||||
Ok(Some(ndarray.instance.value))
|
||||
},
|
||||
|generator, ctx| {
|
||||
let ndarray = NDArrayObject::from_np_array_list_try_no_copy(generator, ctx, list);
|
||||
let ndarray =
|
||||
NDArrayObject::from_np_array_list_try_no_copy_impl(generator, ctx, list);
|
||||
Ok(Some(ndarray.instance.value))
|
||||
},
|
||||
)
|
||||
|
@ -132,7 +131,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
|||
NDArrayObject::from_value_and_unpacked_types(generator, ctx, ndarray, dtype, ndims)
|
||||
}
|
||||
|
||||
fn from_np_array_ndarray<G: CodeGenerator + ?Sized>(
|
||||
pub fn from_np_array_ndarray_impl<G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
ndarray: NDArrayObject<'ctx>,
|
||||
|
@ -166,24 +165,23 @@ impl<'ctx> NDArrayObject<'ctx> {
|
|||
pub fn from_np_array<G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
object: BasicValueEnum<'ctx>,
|
||||
object_ty: Type,
|
||||
object: AnyObject<'ctx>,
|
||||
copy: Int<'ctx, Bool>,
|
||||
) -> Self {
|
||||
match &*ctx.unifier.get_ty(object_ty) {
|
||||
match &*ctx.unifier.get_ty(object.ty) {
|
||||
TypeEnum::TObj { obj_id, .. }
|
||||
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
|
||||
{
|
||||
let list = ListObject::from_value_and_type(generator, ctx, object, object_ty);
|
||||
NDArrayObject::from_np_array_list(generator, ctx, list, copy)
|
||||
let list = ListObject::from_object(generator, ctx, object);
|
||||
NDArrayObject::from_np_array_list_impl(generator, ctx, list, copy)
|
||||
}
|
||||
TypeEnum::TObj { obj_id, .. }
|
||||
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
|
||||
{
|
||||
let ndarray = NDArrayObject::from_value_and_type(generator, ctx, object, object_ty);
|
||||
NDArrayObject::from_np_array_ndarray(generator, ctx, ndarray, copy)
|
||||
let ndarray = NDArrayObject::from_object(generator, ctx, object);
|
||||
NDArrayObject::from_np_array_ndarray_impl(generator, ctx, ndarray, copy)
|
||||
}
|
||||
_ => panic!("Unrecognized object type: {}", ctx.unifier.stringify(object_ty)), // Typechecker ensures this
|
||||
_ => panic!("Unrecognized object type: {}", ctx.unifier.stringify(object.ty)), // Typechecker ensures this
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -35,6 +35,8 @@ use inkwell::{
|
|||
use scalar::{ScalarObject, ScalarOrNDArray};
|
||||
use util::{call_memcpy_model, gen_for_model_auto};
|
||||
|
||||
use super::AnyObject;
|
||||
|
||||
/// A NAC3 Python ndarray object.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct NDArrayObject<'ctx> {
|
||||
|
@ -45,18 +47,17 @@ pub struct NDArrayObject<'ctx> {
|
|||
|
||||
impl<'ctx> NDArrayObject<'ctx> {
|
||||
/// Create an [`NDArrayObject`] from an LLVM value and its typechecker [`Type`].
|
||||
pub fn from_value_and_type<V: BasicValue<'ctx>, G: CodeGenerator + ?Sized>(
|
||||
pub fn from_object<G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
value: V,
|
||||
ty: Type,
|
||||
object: AnyObject<'ctx>,
|
||||
) -> Self {
|
||||
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
||||
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, object.ty);
|
||||
let ndims = extract_ndims(&ctx.unifier, ndims);
|
||||
Self::from_value_and_unpacked_types(generator, ctx, value, dtype, ndims)
|
||||
Self::from_value_and_unpacked_types(generator, ctx, object.value, dtype, ndims)
|
||||
}
|
||||
|
||||
/// Like [`NDArrayObject::from_value_and_type`] but you directly supply the ndarray's
|
||||
/// Like [`NDArrayObject::from_object`] but you directly supply the ndarray's
|
||||
/// `dtype` and `ndims`.
|
||||
pub fn from_value_and_unpacked_types<V: BasicValue<'ctx>, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use inkwell::values::{BasicValue, BasicValueEnum};
|
||||
|
||||
use crate::{
|
||||
codegen::{model::*, CodeGenContext, CodeGenerator},
|
||||
codegen::{model::*, object::AnyObject, CodeGenContext, CodeGenerator},
|
||||
typecheck::typedef::{Type, TypeEnum},
|
||||
};
|
||||
|
||||
|
@ -120,23 +120,22 @@ impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for NDArrayObject<'ctx> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Split an [`BasicValueEnum<'ctx>`] into a [`ScalarOrNDArray`] depending
|
||||
/// Split an [`AnyObject`] into a [`ScalarOrNDArray`] depending
|
||||
/// on its [`Type`].
|
||||
pub fn split_scalar_or_ndarray<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
input: BasicValueEnum<'ctx>,
|
||||
input_ty: Type,
|
||||
object: AnyObject<'ctx>,
|
||||
) -> ScalarOrNDArray<'ctx> {
|
||||
match &*ctx.unifier.get_ty(input_ty) {
|
||||
match &*ctx.unifier.get_ty(object.ty) {
|
||||
TypeEnum::TObj { obj_id, .. }
|
||||
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
|
||||
{
|
||||
let ndarray = NDArrayObject::from_value_and_type(generator, ctx, input, input_ty);
|
||||
let ndarray = NDArrayObject::from_object(generator, ctx, object);
|
||||
ScalarOrNDArray::NDArray(ndarray)
|
||||
}
|
||||
_ => {
|
||||
let scalar = ScalarObject { dtype: input_ty, instance: input };
|
||||
let scalar = ScalarObject { dtype: object.ty, instance: object.value };
|
||||
ScalarOrNDArray::Scalar(scalar)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
use inkwell::values::BasicValueEnum;
|
||||
use util::gen_for_model_auto;
|
||||
|
||||
use crate::{
|
||||
codegen::{model::*, object::list::ListObject, CodeGenContext, CodeGenerator},
|
||||
typecheck::typedef::{Type, TypeEnum},
|
||||
codegen::{
|
||||
model::*,
|
||||
object::{list::ListObject, tuple::TupleObject, AnyObject},
|
||||
CodeGenContext, CodeGenerator,
|
||||
},
|
||||
typecheck::typedef::TypeEnum,
|
||||
};
|
||||
|
||||
/// Parse a NumPy-like "int sequence" input and return the int sequence as an array and its length.
|
||||
|
@ -20,8 +23,7 @@ use crate::{
|
|||
pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
input_sequence: BasicValueEnum<'ctx>,
|
||||
input_sequence_ty: Type,
|
||||
input_sequence: AnyObject<'ctx>,
|
||||
) -> (Int<'ctx, SizeT>, Ptr<'ctx, IntModel<SizeT>>) {
|
||||
let sizet_model = IntModel(SizeT);
|
||||
|
||||
|
@ -29,15 +31,14 @@ pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
|
|||
let one = sizet_model.const_1(generator, ctx.ctx);
|
||||
|
||||
// The result `list` to return.
|
||||
match &*ctx.unifier.get_ty(input_sequence_ty) {
|
||||
match &*ctx.unifier.get_ty(input_sequence.ty) {
|
||||
TypeEnum::TObj { obj_id, .. }
|
||||
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
|
||||
{
|
||||
// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
|
||||
|
||||
// Check `input_sequence`
|
||||
let input_sequence =
|
||||
ListObject::from_value_and_type(generator, ctx, input_sequence, input_sequence_ty);
|
||||
let input_sequence = ListObject::from_object(generator, ctx, input_sequence);
|
||||
|
||||
let len = input_sequence.instance.gep(ctx, |f| f.len).load(generator, ctx, "len");
|
||||
let result = sizet_model.array_alloca(generator, ctx, len.value, "int_sequence");
|
||||
|
@ -66,20 +67,17 @@ pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
|
|||
}
|
||||
TypeEnum::TTuple { ty: tuple_types, .. } => {
|
||||
// 2. A tuple of ints; e.g., `np.empty((600, 800, 3))`
|
||||
let input_sequence = input_sequence.into_struct_value(); // A tuple is a struct
|
||||
|
||||
let len_int = tuple_types.len();
|
||||
let input_sequence = TupleObject::from_object(ctx, input_sequence);
|
||||
|
||||
let len_int = input_sequence.len();
|
||||
let len = sizet_model.constant(generator, ctx.ctx, len_int as u64);
|
||||
|
||||
let result = sizet_model.array_alloca(generator, ctx, len.value, "int_sequence");
|
||||
|
||||
for i in 0..len_int {
|
||||
// Get the i-th element off of the tuple and load it into `result`.
|
||||
let int = ctx
|
||||
.builder
|
||||
.build_extract_value(input_sequence, i as u32, "int")
|
||||
.unwrap()
|
||||
.into_int_value();
|
||||
let int = input_sequence.get(ctx, i, "dim").value.into_int_value();
|
||||
let int = sizet_model.s_extend_or_bit_cast(generator, ctx, int, "int");
|
||||
|
||||
let offset = sizet_model.constant(generator, ctx.ctx, i as u64);
|
||||
|
@ -92,7 +90,7 @@ pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
|
|||
if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() =>
|
||||
{
|
||||
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
|
||||
let input_int = input_sequence.into_int_value();
|
||||
let input_int = input_sequence.value.into_int_value();
|
||||
|
||||
let len = sizet_model.const_1(generator, ctx.ctx);
|
||||
let result = sizet_model.array_alloca(generator, ctx, len.value, "int_sequence");
|
||||
|
@ -106,7 +104,7 @@ pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
|
|||
}
|
||||
_ => panic!(
|
||||
"encountered unknown sequence type: {}",
|
||||
ctx.unifier.stringify(input_sequence_ty)
|
||||
ctx.unifier.stringify(input_sequence.ty)
|
||||
),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,34 +1,68 @@
|
|||
use inkwell::values::{BasicValueEnum, StructValue};
|
||||
use core::panic;
|
||||
|
||||
use inkwell::values::StructValue;
|
||||
use itertools::Itertools;
|
||||
|
||||
use crate::{
|
||||
codegen::{CodeGenContext, CodeGenerator},
|
||||
typecheck::typedef::Type,
|
||||
typecheck::typedef::{Type, TypeEnum},
|
||||
};
|
||||
|
||||
use super::AnyObject;
|
||||
|
||||
/// A NAC3 tuple object.
|
||||
///
|
||||
/// NOTE: This struct has no copy trait.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TupleObject<'ctx> {
|
||||
/// The type of the tuple.
|
||||
pub tys: Vec<Type>,
|
||||
/// The underlying LLVM value of this tuple.
|
||||
pub value: StructValue<'ctx>,
|
||||
}
|
||||
|
||||
impl<'ctx> TupleObject<'ctx> {
|
||||
// NOTE: There is no Model abstraction for Tuples. Everything has to be done raw with Inkwell.
|
||||
|
||||
pub fn from_object(ctx: &mut CodeGenContext<'ctx, '_>, object: AnyObject<'ctx>) -> Self {
|
||||
// TODO: Keep `is_vararg_ctx` from TTuple?
|
||||
let TypeEnum::TTuple { ty: tys, .. } = &*ctx.unifier.get_ty(object.ty) else {
|
||||
panic!(
|
||||
"Expected type to be a TypeEnum::TTuple, got {}",
|
||||
ctx.unifier.stringify(object.ty)
|
||||
);
|
||||
};
|
||||
|
||||
let value = object.value.into_struct_value();
|
||||
if value.get_type().count_fields() as usize != tys.len() {
|
||||
panic!(
|
||||
"Tuple type has {} item(s), but the LLVM struct value has {} field(s)",
|
||||
tys.len(),
|
||||
value.get_type().count_fields()
|
||||
);
|
||||
}
|
||||
|
||||
TupleObject { tys: tys.clone(), value }
|
||||
}
|
||||
|
||||
/// Convenience function. Create a [`TupleObject`] from an iterator of objects.
|
||||
pub fn create<I, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
items: I,
|
||||
objects: I,
|
||||
name: &str,
|
||||
) -> Self
|
||||
where
|
||||
I: IntoIterator<Item = (BasicValueEnum<'ctx>, Type)>,
|
||||
I: IntoIterator<Item = AnyObject<'ctx>>,
|
||||
{
|
||||
let (vals, tys): (Vec<_>, Vec<_>) = items.into_iter().unzip();
|
||||
let (values, tys): (Vec<_>, Vec<_>) =
|
||||
objects.into_iter().map(|object| (object.value, object.ty)).unzip();
|
||||
|
||||
// let tuple_ty = ctx.unifier.add_ty(TypeEnum::TTuple { ty: tys });
|
||||
let llvm_tys = tys.iter().map(|ty| ctx.get_llvm_type(generator, *ty)).collect_vec();
|
||||
let llvm_tuple_ty = ctx.ctx.struct_type(&llvm_tys, false);
|
||||
|
||||
let pllvm_tuple = ctx.builder.build_alloca(llvm_tuple_ty, "tuple").unwrap();
|
||||
for (i, val) in vals.into_iter().enumerate() {
|
||||
// Store the dim value into the tuple
|
||||
for (i, val) in values.into_iter().enumerate() {
|
||||
let pval = ctx.builder.build_struct_gep(pllvm_tuple, i as u32, "value").unwrap();
|
||||
ctx.builder.build_store(pval, val).unwrap();
|
||||
}
|
||||
|
@ -36,4 +70,22 @@ impl<'ctx> TupleObject<'ctx> {
|
|||
let value = ctx.builder.build_load(pllvm_tuple, name).unwrap().into_struct_value();
|
||||
TupleObject { tys, value }
|
||||
}
|
||||
|
||||
/// Get the `len()` of this tuple.
|
||||
///
|
||||
/// We statically know the lengths of tuples in NAC3.
|
||||
pub fn len(&self) -> usize {
|
||||
self.tys.len()
|
||||
}
|
||||
|
||||
/// Get the `i`-th (0-based) object in this tuple.
|
||||
pub fn get(&self, ctx: &mut CodeGenContext<'ctx, '_>, i: usize, name: &str) -> AnyObject<'ctx> {
|
||||
if i >= self.len() {
|
||||
panic!("Tuple object with length {} have index {i}", self.len());
|
||||
}
|
||||
|
||||
let value = ctx.builder.build_extract_value(self.value, i as u32, name).unwrap();
|
||||
let ty = self.tys[i];
|
||||
AnyObject { value, ty }
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ use super::model::*;
|
|||
use super::object::ndarray::indexing::util::gen_ndarray_subscript_ndindexes;
|
||||
use super::object::ndarray::scalar::split_scalar_or_ndarray;
|
||||
use super::object::ndarray::NDArrayObject;
|
||||
use super::object::AnyObject;
|
||||
use super::{
|
||||
super::symbol_resolver::ValueEnum,
|
||||
expr::destructure_range,
|
||||
|
@ -411,13 +412,14 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
|
|||
.gen_expr(ctx, target)?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, target_ty)?;
|
||||
let target = NDArrayObject::from_value_and_type(generator, ctx, target, target_ty);
|
||||
let target = AnyObject { value: target, ty: target_ty };
|
||||
|
||||
// Process key
|
||||
let key = gen_ndarray_subscript_ndindexes(generator, ctx, key)?;
|
||||
|
||||
// Process value
|
||||
let value = value.to_basic_value_enum(ctx, generator, value_ty)?;
|
||||
let value = AnyObject { value, ty: value_ty };
|
||||
|
||||
/*
|
||||
Reference code:
|
||||
|
@ -433,9 +435,10 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
|
|||
...and finally copy 1-1 from value to target.
|
||||
```
|
||||
*/
|
||||
let target = NDArrayObject::from_object(generator, ctx, target);
|
||||
let target = target.index(generator, ctx, &key, "assign_target_ndarray");
|
||||
let value =
|
||||
split_scalar_or_ndarray(generator, ctx, value, value_ty).as_ndarray(generator, ctx);
|
||||
|
||||
let value = split_scalar_or_ndarray(generator, ctx, value).as_ndarray(generator, ctx);
|
||||
|
||||
let broadcast_result = NDArrayObject::broadcast(generator, ctx, &[target, value]);
|
||||
|
||||
|
|
|
@ -18,10 +18,13 @@ use crate::{
|
|||
extern_fns, irrt, llvm_intrinsics,
|
||||
numpy::*,
|
||||
numpy_new::{self, gen_ndarray_transpose},
|
||||
object::ndarray::{
|
||||
object::{
|
||||
ndarray::{
|
||||
functions::{FloorOrCeil, MinOrMax},
|
||||
scalar::{split_scalar_or_ndarray, ScalarObject, ScalarOrNDArray},
|
||||
},
|
||||
AnyObject,
|
||||
},
|
||||
stmt::exn_constructor,
|
||||
},
|
||||
symbol_resolver::SymbolValue,
|
||||
|
@ -1080,6 +1083,7 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
move |ctx, _, fun, args, generator| {
|
||||
let arg_ty = fun.0.args[0].ty;
|
||||
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||
let arg = AnyObject { ty: arg_ty, value: arg };
|
||||
|
||||
let ret_dtype = match prim {
|
||||
PrimDef::FunInt32 => ctx.primitives.int32,
|
||||
|
@ -1091,7 +1095,7 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let result = split_scalar_or_ndarray(generator, ctx, arg, arg_ty).map(
|
||||
let result = split_scalar_or_ndarray(generator, ctx, arg).map(
|
||||
generator,
|
||||
ctx,
|
||||
ret_dtype,
|
||||
|
@ -1158,10 +1162,11 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
Box::new(move |ctx, _, fun, args, generator| {
|
||||
let arg_ty = fun.0.args[0].ty;
|
||||
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||
let arg = AnyObject { ty: arg_ty, value: arg };
|
||||
|
||||
let ret_int_dtype = size_variant.of_int(&ctx.primitives);
|
||||
|
||||
let result = split_scalar_or_ndarray(generator, ctx, arg, arg_ty).map(
|
||||
let result = split_scalar_or_ndarray(generator, ctx, arg).map(
|
||||
generator,
|
||||
ctx,
|
||||
ret_int_dtype,
|
||||
|
@ -1225,8 +1230,9 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
Box::new(move |ctx, _, fun, args, generator| {
|
||||
let arg_ty = fun.0.args[0].ty;
|
||||
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||
let arg = AnyObject { ty: arg_ty, value: arg };
|
||||
|
||||
let result = split_scalar_or_ndarray(generator, ctx, arg, arg_ty).map(
|
||||
let result = split_scalar_or_ndarray(generator, ctx, arg).map(
|
||||
generator,
|
||||
ctx,
|
||||
int_sized,
|
||||
|
@ -1618,6 +1624,7 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
Box::new(move |ctx, _, fun, args, generator| {
|
||||
let arg_ty = fun.0.args[0].ty;
|
||||
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||
let arg = AnyObject { ty: arg_ty, value: arg };
|
||||
|
||||
let kind = match prim {
|
||||
PrimDef::FunNpFloor => FloorOrCeil::Floor,
|
||||
|
@ -1625,7 +1632,7 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let result = split_scalar_or_ndarray(generator, ctx, arg, arg_ty).map(
|
||||
let result = split_scalar_or_ndarray(generator, ctx, arg).map(
|
||||
generator,
|
||||
ctx,
|
||||
ctx.primitives.float,
|
||||
|
@ -1652,8 +1659,9 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
Box::new(|ctx, _, fun, args, generator| {
|
||||
let arg_ty = fun.0.args[0].ty;
|
||||
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||
let arg = AnyObject { ty: arg_ty, value: arg };
|
||||
|
||||
let result = split_scalar_or_ndarray(generator, ctx, arg, arg_ty).map(
|
||||
let result = split_scalar_or_ndarray(generator, ctx, arg).map(
|
||||
generator,
|
||||
ctx,
|
||||
ctx.primitives.float,
|
||||
|
@ -1790,8 +1798,9 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
Box::new(move |ctx, _, fun, args, generator| {
|
||||
let a_ty = fun.0.args[0].ty;
|
||||
let a = args[0].1.clone().to_basic_value_enum(ctx, generator, a_ty)?;
|
||||
let a = AnyObject { ty: a_ty, value: a };
|
||||
|
||||
let a = split_scalar_or_ndarray(generator, ctx, a, a_ty).as_ndarray(generator, ctx);
|
||||
let a = split_scalar_or_ndarray(generator, ctx, a).as_ndarray(generator, ctx);
|
||||
let result = match prim {
|
||||
PrimDef::FunNpArgmin => a
|
||||
.argmin_or_argmax(generator, ctx, MinOrMax::Min)
|
||||
|
@ -1845,9 +1854,12 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||
move |ctx, _, fun, args, generator| {
|
||||
let x1_ty = fun.0.args[0].ty;
|
||||
let x2_ty = fun.0.args[1].ty;
|
||||
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||
let x1 = AnyObject { ty: x1_ty, value: x1_val };
|
||||
|
||||
let x2_ty = fun.0.args[1].ty;
|
||||
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
||||
let x2 = AnyObject { ty: x2_ty, value: x2_val };
|
||||
|
||||
let kind = match prim {
|
||||
PrimDef::FunNpMinimum => MinOrMax::Min,
|
||||
|
@ -1855,8 +1867,8 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let x1 = split_scalar_or_ndarray(generator, ctx, x1_val, x1_ty);
|
||||
let x2 = split_scalar_or_ndarray(generator, ctx, x2_val, x2_ty);
|
||||
let x1 = split_scalar_or_ndarray(generator, ctx, x1);
|
||||
let x2 = split_scalar_or_ndarray(generator, ctx, x2);
|
||||
|
||||
// NOTE: x1.dtype() and x2.dtype() should be the same
|
||||
let common_ty = x1.dtype();
|
||||
|
@ -1907,8 +1919,9 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
move |ctx, _, fun, args, generator| {
|
||||
let n_ty = fun.0.args[0].ty;
|
||||
let n_val = args[0].1.clone().to_basic_value_enum(ctx, generator, n_ty)?;
|
||||
let n = AnyObject { ty: n_ty, value: n_val };
|
||||
|
||||
let result = split_scalar_or_ndarray(generator, ctx, n_val, n_ty).map(
|
||||
let result = split_scalar_or_ndarray(generator, ctx, n).map(
|
||||
generator,
|
||||
ctx,
|
||||
num_ty.ty,
|
||||
|
@ -1936,6 +1949,7 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
Box::new(move |ctx, _, fun, args, generator| {
|
||||
let x_ty = fun.0.args[0].ty;
|
||||
let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?;
|
||||
let x = AnyObject { value: x_val, ty: x_ty };
|
||||
|
||||
let function = match prim {
|
||||
PrimDef::FunNpIsInf => irrt::call_isnan,
|
||||
|
@ -1943,7 +1957,7 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let result = split_scalar_or_ndarray(generator, ctx, x_val, x_ty).map(
|
||||
let result = split_scalar_or_ndarray(generator, ctx, x).map(
|
||||
generator,
|
||||
ctx,
|
||||
ctx.primitives.bool,
|
||||
|
@ -2010,8 +2024,9 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
Box::new(move |ctx, _, fun, args, generator| {
|
||||
let arg_ty = fun.0.args[0].ty;
|
||||
let arg_val = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||
let arg = AnyObject { ty: arg_ty, value: arg_val };
|
||||
|
||||
let result = split_scalar_or_ndarray(generator, ctx, arg_val, arg_ty).map(
|
||||
let result = split_scalar_or_ndarray(generator, ctx, arg).map(
|
||||
generator,
|
||||
ctx,
|
||||
ctx.primitives.float,
|
||||
|
@ -2126,12 +2141,14 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
move |ctx, _, fun, args, generator| {
|
||||
let x1_ty = fun.0.args[0].ty;
|
||||
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||
let x1 = AnyObject { ty: x1_ty, value: x1_val };
|
||||
|
||||
let x2_ty = fun.0.args[1].ty;
|
||||
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
||||
let x2 = AnyObject { ty: x2_ty, value: x2_val };
|
||||
|
||||
let x1 = split_scalar_or_ndarray(generator, ctx, x1_val, x1_ty);
|
||||
let x2 = split_scalar_or_ndarray(generator, ctx, x2_val, x2_ty);
|
||||
let x1 = split_scalar_or_ndarray(generator, ctx, x1);
|
||||
let x2 = split_scalar_or_ndarray(generator, ctx, x2);
|
||||
|
||||
let result = ScalarOrNDArray::broadcasting_starmap(
|
||||
generator,
|
||||
|
|
Loading…
Reference in New Issue