forked from M-Labs/nac3
WIP: core/ndstrides: AnyObject + TupleObject
This commit is contained in:
parent
febe78b6a4
commit
2fbe981701
|
@ -1,15 +1,14 @@
|
||||||
use inkwell::types::BasicTypeEnum;
|
use inkwell::types::BasicTypeEnum;
|
||||||
use inkwell::values::{BasicValue, BasicValueEnum, IntValue, PointerValue};
|
use inkwell::values::{BasicValue, BasicValueEnum, IntValue, PointerValue};
|
||||||
use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
|
use inkwell::IntPredicate;
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
|
|
||||||
use crate::codegen::classes::{
|
use crate::codegen::classes::{
|
||||||
ArrayLikeValue, NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor,
|
ArrayLikeValue, NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor,
|
||||||
UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
|
UntypedArrayLikeAccessor,
|
||||||
};
|
};
|
||||||
use crate::codegen::expr::destructure_range;
|
use crate::codegen::expr::destructure_range;
|
||||||
use crate::codegen::irrt::calculate_len_for_slice_range;
|
use crate::codegen::irrt::calculate_len_for_slice_range;
|
||||||
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
|
|
||||||
use crate::codegen::{extern_fns, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
|
use crate::codegen::{extern_fns, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
|
||||||
use crate::toplevel::helper::PrimDef;
|
use crate::toplevel::helper::PrimDef;
|
||||||
use crate::toplevel::numpy::unpack_ndarray_var_tys;
|
use crate::toplevel::numpy::unpack_ndarray_var_tys;
|
||||||
|
|
|
@ -12,7 +12,10 @@ use crate::{
|
||||||
call_memcpy_generic,
|
call_memcpy_generic,
|
||||||
},
|
},
|
||||||
need_sret, numpy,
|
need_sret, numpy,
|
||||||
object::ndarray::{scalar::split_scalar_or_ndarray, NDArrayObject, NDArrayOut},
|
object::{
|
||||||
|
ndarray::{scalar::split_scalar_or_ndarray, NDArrayObject, NDArrayOut},
|
||||||
|
AnyObject,
|
||||||
|
},
|
||||||
stmt::{
|
stmt::{
|
||||||
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
|
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
|
||||||
gen_var,
|
gen_var,
|
||||||
|
@ -1537,10 +1540,11 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
} else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
} else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||||
|| ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
|| ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||||
{
|
{
|
||||||
let left =
|
let left = AnyObject { ty: ty1, value: left_val };
|
||||||
split_scalar_or_ndarray(generator, ctx, left_val, ty1).as_ndarray(generator, ctx);
|
let right = AnyObject { ty: ty1, value: right_val };
|
||||||
let right =
|
|
||||||
split_scalar_or_ndarray(generator, ctx, right_val, ty2).as_ndarray(generator, ctx);
|
let left = split_scalar_or_ndarray(generator, ctx, left).as_ndarray(generator, ctx);
|
||||||
|
let right = split_scalar_or_ndarray(generator, ctx, right).as_ndarray(generator, ctx);
|
||||||
|
|
||||||
debug_assert!(ctx.unifier.unioned(left.dtype, right.dtype)); // Typechecker ensures this.
|
debug_assert!(ctx.unifier.unioned(left.dtype, right.dtype)); // Typechecker ensures this.
|
||||||
|
|
||||||
|
@ -2860,8 +2864,9 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||||
|
|
||||||
let ndarray_ty = value.custom.unwrap();
|
let ndarray_ty = value.custom.unwrap();
|
||||||
let ndarray = ndarray.to_basic_value_enum(ctx, generator, ndarray_ty)?;
|
let ndarray = ndarray.to_basic_value_enum(ctx, generator, ndarray_ty)?;
|
||||||
let ndarray =
|
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
|
||||||
NDArrayObject::from_value_and_type(generator, ctx, ndarray, ndarray_ty);
|
|
||||||
|
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
|
||||||
|
|
||||||
let indexes = gen_ndarray_subscript_ndindexes(generator, ctx, slice)?;
|
let indexes = gen_ndarray_subscript_ndindexes(generator, ctx, slice)?;
|
||||||
let result = ndarray
|
let result = ndarray
|
||||||
|
|
|
@ -13,7 +13,10 @@ use crate::{
|
||||||
use super::{
|
use super::{
|
||||||
irrt::call_nac3_ndarray_util_assert_shape_no_negative,
|
irrt::call_nac3_ndarray_util_assert_shape_no_negative,
|
||||||
model::*,
|
model::*,
|
||||||
object::ndarray::{shape_util::parse_numpy_int_sequence, NDArrayObject},
|
object::{
|
||||||
|
ndarray::{shape_util::parse_numpy_int_sequence, NDArrayObject},
|
||||||
|
AnyObject,
|
||||||
|
},
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -84,13 +87,12 @@ fn create_empty_ndarray<'ctx, G>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
ndarray_ty: Type,
|
ndarray_ty: Type,
|
||||||
shape: BasicValueEnum<'ctx>,
|
shape: AnyObject<'ctx>,
|
||||||
shape_ty: Type,
|
|
||||||
) -> NDArrayObject<'ctx>
|
) -> NDArrayObject<'ctx>
|
||||||
where
|
where
|
||||||
G: CodeGenerator + ?Sized,
|
G: CodeGenerator + ?Sized,
|
||||||
{
|
{
|
||||||
let (_, shape) = parse_numpy_int_sequence(generator, ctx, shape, shape_ty);
|
let (_, shape) = parse_numpy_int_sequence(generator, ctx, shape);
|
||||||
|
|
||||||
let ndarray =
|
let ndarray =
|
||||||
NDArrayObject::alloca_uninitialized_of_type(generator, ctx, ndarray_ty, "ndarray");
|
NDArrayObject::alloca_uninitialized_of_type(generator, ctx, ndarray_ty, "ndarray");
|
||||||
|
@ -120,10 +122,11 @@ pub fn gen_ndarray_empty<'ctx>(
|
||||||
// Parse arguments
|
// Parse arguments
|
||||||
let shape_ty = fun.0.args[0].ty;
|
let shape_ty = fun.0.args[0].ty;
|
||||||
let shape = args[0].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
|
let shape = args[0].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
|
||||||
|
let shape = AnyObject { ty: shape_ty, value: shape };
|
||||||
|
|
||||||
// Implementation
|
// Implementation
|
||||||
let ndarray_ty = fun.0.ret;
|
let ndarray_ty = fun.0.ret;
|
||||||
let ndarray = create_empty_ndarray(generator, ctx, ndarray_ty, shape, shape_ty);
|
let ndarray = create_empty_ndarray(generator, ctx, ndarray_ty, shape);
|
||||||
|
|
||||||
Ok(ndarray.instance.value.as_basic_value_enum())
|
Ok(ndarray.instance.value.as_basic_value_enum())
|
||||||
}
|
}
|
||||||
|
@ -142,10 +145,11 @@ pub fn gen_ndarray_zeros<'ctx>(
|
||||||
// Parse arguments
|
// Parse arguments
|
||||||
let shape_ty = fun.0.args[0].ty;
|
let shape_ty = fun.0.args[0].ty;
|
||||||
let shape = args[0].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
|
let shape = args[0].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
|
||||||
|
let shape = AnyObject { ty: shape_ty, value: shape };
|
||||||
|
|
||||||
// Implementation
|
// Implementation
|
||||||
let ndarray_ty = fun.0.ret;
|
let ndarray_ty = fun.0.ret;
|
||||||
let ndarray = create_empty_ndarray(generator, ctx, ndarray_ty, shape, shape_ty);
|
let ndarray = create_empty_ndarray(generator, ctx, ndarray_ty, shape);
|
||||||
|
|
||||||
let fill_value = ndarray_zero_value(generator, ctx, ndarray.dtype);
|
let fill_value = ndarray_zero_value(generator, ctx, ndarray.dtype);
|
||||||
ndarray.fill(generator, ctx, fill_value);
|
ndarray.fill(generator, ctx, fill_value);
|
||||||
|
@ -167,10 +171,11 @@ pub fn gen_ndarray_ones<'ctx>(
|
||||||
// Parse arguments
|
// Parse arguments
|
||||||
let shape_ty = fun.0.args[0].ty;
|
let shape_ty = fun.0.args[0].ty;
|
||||||
let shape = args[0].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
|
let shape = args[0].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
|
||||||
|
let shape = AnyObject { ty: shape_ty, value: shape };
|
||||||
|
|
||||||
// Implementation
|
// Implementation
|
||||||
let ndarray_ty = fun.0.ret;
|
let ndarray_ty = fun.0.ret;
|
||||||
let ndarray = create_empty_ndarray(generator, ctx, ndarray_ty, shape, shape_ty);
|
let ndarray = create_empty_ndarray(generator, ctx, ndarray_ty, shape);
|
||||||
|
|
||||||
let fill_value = ndarray_one_value(generator, ctx, ndarray.dtype);
|
let fill_value = ndarray_one_value(generator, ctx, ndarray.dtype);
|
||||||
ndarray.fill(generator, ctx, fill_value);
|
ndarray.fill(generator, ctx, fill_value);
|
||||||
|
@ -192,6 +197,7 @@ pub fn gen_ndarray_full<'ctx>(
|
||||||
// Parse argument #1 shape
|
// Parse argument #1 shape
|
||||||
let shape_ty = fun.0.args[0].ty;
|
let shape_ty = fun.0.args[0].ty;
|
||||||
let shape = args[0].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
|
let shape = args[0].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
|
||||||
|
let shape = AnyObject { ty: shape_ty, value: shape };
|
||||||
|
|
||||||
// Parse argument #2 fill_value
|
// Parse argument #2 fill_value
|
||||||
let fill_value_ty = fun.0.args[1].ty;
|
let fill_value_ty = fun.0.args[1].ty;
|
||||||
|
@ -199,7 +205,7 @@ pub fn gen_ndarray_full<'ctx>(
|
||||||
|
|
||||||
// Implementation
|
// Implementation
|
||||||
let ndarray_ty = fun.0.ret;
|
let ndarray_ty = fun.0.ret;
|
||||||
let ndarray = create_empty_ndarray(generator, ctx, ndarray_ty, shape, shape_ty);
|
let ndarray = create_empty_ndarray(generator, ctx, ndarray_ty, shape);
|
||||||
|
|
||||||
ndarray.fill(generator, ctx, fill_value);
|
ndarray.fill(generator, ctx, fill_value);
|
||||||
|
|
||||||
|
@ -220,10 +226,12 @@ pub fn gen_ndarray_broadcast_to<'ctx>(
|
||||||
// Parse argument #1 input
|
// Parse argument #1 input
|
||||||
let input_ty = fun.0.args[0].ty;
|
let input_ty = fun.0.args[0].ty;
|
||||||
let input = args[0].1.clone().to_basic_value_enum(ctx, generator, input_ty)?;
|
let input = args[0].1.clone().to_basic_value_enum(ctx, generator, input_ty)?;
|
||||||
|
let input = AnyObject { ty: input_ty, value: input };
|
||||||
|
|
||||||
// Parse argument #2 shape
|
// Parse argument #2 shape
|
||||||
let shape_ty = fun.0.args[1].ty;
|
let shape_ty = fun.0.args[1].ty;
|
||||||
let shape = args[1].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
|
let shape = args[1].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
|
||||||
|
let shape = AnyObject { ty: shape_ty, value: shape };
|
||||||
|
|
||||||
// Define models
|
// Define models
|
||||||
let sizet_model = IntModel(SizeT);
|
let sizet_model = IntModel(SizeT);
|
||||||
|
@ -234,11 +242,10 @@ pub fn gen_ndarray_broadcast_to<'ctx>(
|
||||||
let broadcast_ndims = extract_ndims(&ctx.unifier, broadcast_ndims_ty);
|
let broadcast_ndims = extract_ndims(&ctx.unifier, broadcast_ndims_ty);
|
||||||
|
|
||||||
// Process `input`
|
// Process `input`
|
||||||
let in_ndarray =
|
let in_ndarray = split_scalar_or_ndarray(generator, ctx, input).as_ndarray(generator, ctx);
|
||||||
split_scalar_or_ndarray(generator, ctx, input, input_ty).as_ndarray(generator, ctx);
|
|
||||||
|
|
||||||
// Process `shape`
|
// Process `shape`
|
||||||
let (_, broadcast_shape) = parse_numpy_int_sequence(generator, ctx, shape, shape_ty);
|
let (_, broadcast_shape) = parse_numpy_int_sequence(generator, ctx, shape);
|
||||||
// NOTE: shape.size should equal to `broadcasted_ndims`.
|
// NOTE: shape.size should equal to `broadcasted_ndims`.
|
||||||
let broadcast_ndims_llvm = sizet_model.constant(generator, ctx.ctx, broadcast_ndims);
|
let broadcast_ndims_llvm = sizet_model.constant(generator, ctx.ctx, broadcast_ndims);
|
||||||
call_nac3_ndarray_util_assert_shape_no_negative(
|
call_nac3_ndarray_util_assert_shape_no_negative(
|
||||||
|
@ -269,23 +276,24 @@ pub fn gen_ndarray_reshape<'ctx>(
|
||||||
// Parse argument #1 input
|
// Parse argument #1 input
|
||||||
let input_ty = fun.0.args[0].ty;
|
let input_ty = fun.0.args[0].ty;
|
||||||
let input = args[0].1.clone().to_basic_value_enum(ctx, generator, input_ty)?;
|
let input = args[0].1.clone().to_basic_value_enum(ctx, generator, input_ty)?;
|
||||||
|
let input = AnyObject { ty: input_ty, value: input };
|
||||||
|
|
||||||
// Parse argument #2 shape
|
// Parse argument #2 shape
|
||||||
let shape_ty = fun.0.args[1].ty;
|
let shape_ty = fun.0.args[1].ty;
|
||||||
let shape = args[1].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
|
let shape = args[1].1.clone().to_basic_value_enum(ctx, generator, shape_ty)?;
|
||||||
|
let shape = AnyObject { ty: shape_ty, value: shape };
|
||||||
|
|
||||||
// Extract reshaped_ndims
|
// Extract reshaped_ndims
|
||||||
let (_, reshaped_ndims_ty) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
|
let (_, reshaped_ndims_ty) = unpack_ndarray_var_tys(&mut ctx.unifier, fun.0.ret);
|
||||||
let reshaped_ndims = extract_ndims(&ctx.unifier, reshaped_ndims_ty);
|
let reshaped_ndims = extract_ndims(&ctx.unifier, reshaped_ndims_ty);
|
||||||
|
|
||||||
// Process `input`
|
// Process `input`
|
||||||
let in_ndarray =
|
let in_ndarray = split_scalar_or_ndarray(generator, ctx, input).as_ndarray(generator, ctx);
|
||||||
split_scalar_or_ndarray(generator, ctx, input, input_ty).as_ndarray(generator, ctx);
|
|
||||||
|
|
||||||
// Process the shape input from user and resolve negative indices.
|
// Process the shape input from user and resolve negative indices.
|
||||||
// The resulting `new_shape`'s size should be equal to reshaped_ndims.
|
// The resulting `new_shape`'s size should be equal to reshaped_ndims.
|
||||||
// This is ensured by the typechecker.
|
// This is ensured by the typechecker.
|
||||||
let (_, new_shape) = parse_numpy_int_sequence(generator, ctx, shape, shape_ty);
|
let (_, new_shape) = parse_numpy_int_sequence(generator, ctx, shape);
|
||||||
let reshaped_ndarray = in_ndarray.reshape_or_copy(generator, ctx, reshaped_ndims, new_shape);
|
let reshaped_ndarray = in_ndarray.reshape_or_copy(generator, ctx, reshaped_ndims, new_shape);
|
||||||
|
|
||||||
Ok(reshaped_ndarray.instance.value.as_basic_value_enum())
|
Ok(reshaped_ndarray.instance.value.as_basic_value_enum())
|
||||||
|
@ -354,8 +362,8 @@ pub fn gen_ndarray_size<'ctx>(
|
||||||
|
|
||||||
let ndarray_ty = fun.0.args[0].ty;
|
let ndarray_ty = fun.0.args[0].ty;
|
||||||
let ndarray = args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
|
let ndarray = args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
|
||||||
|
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
|
||||||
let ndarray = NDArrayObject::from_value_and_type(generator, ctx, ndarray, ndarray_ty);
|
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
|
||||||
|
|
||||||
let size = ndarray.size(generator, ctx).truncate(generator, ctx, Int32, "size");
|
let size = ndarray.size(generator, ctx).truncate(generator, ctx, Int32, "size");
|
||||||
Ok(size.value.as_basic_value_enum())
|
Ok(size.value.as_basic_value_enum())
|
||||||
|
@ -375,14 +383,15 @@ pub fn gen_ndarray_shape<'ctx>(
|
||||||
// Parse argument #1 ndarray
|
// Parse argument #1 ndarray
|
||||||
let ndarray_ty = fun.0.args[0].ty;
|
let ndarray_ty = fun.0.args[0].ty;
|
||||||
let ndarray = args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
|
let ndarray = args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
|
||||||
|
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
|
||||||
|
|
||||||
// Define models
|
// Define models
|
||||||
let sizet_model = IntModel(SizeT);
|
let sizet_model = IntModel(SizeT);
|
||||||
|
|
||||||
// Process ndarray
|
// Process ndarray
|
||||||
let ndarray = NDArrayObject::from_value_and_type(generator, ctx, ndarray, ndarray_ty);
|
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
|
||||||
|
|
||||||
let mut items = Vec::with_capacity(ndarray.ndims as usize);
|
let mut objects = Vec::with_capacity(ndarray.ndims as usize);
|
||||||
|
|
||||||
for i in 0..ndarray.ndims {
|
for i in 0..ndarray.ndims {
|
||||||
let i = sizet_model.constant(generator, ctx.ctx, i);
|
let i = sizet_model.constant(generator, ctx.ctx, i);
|
||||||
|
@ -392,10 +401,11 @@ pub fn gen_ndarray_shape<'ctx>(
|
||||||
.ix(generator, ctx, i.value, "dim");
|
.ix(generator, ctx, i.value, "dim");
|
||||||
let dim = dim.truncate(generator, ctx, Int32, "dim"); // TODO: keep using SizeT
|
let dim = dim.truncate(generator, ctx, Int32, "dim"); // TODO: keep using SizeT
|
||||||
|
|
||||||
items.push((dim.value.as_basic_value_enum(), ctx.primitives.int32));
|
objects
|
||||||
|
.push(AnyObject { ty: ctx.primitives.int32, value: dim.value.as_basic_value_enum() });
|
||||||
}
|
}
|
||||||
|
|
||||||
let shape = TupleObject::create(generator, ctx, items, "shape");
|
let shape = TupleObject::create(generator, ctx, objects, "shape");
|
||||||
Ok(shape.value.as_basic_value_enum())
|
Ok(shape.value.as_basic_value_enum())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -407,7 +417,7 @@ pub fn gen_ndarray_strides<'ctx>(
|
||||||
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
args: &[(Option<StrRef>, ValueEnum<'ctx>)],
|
||||||
generator: &mut dyn CodeGenerator,
|
generator: &mut dyn CodeGenerator,
|
||||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
// TODO: This function looks exactly like `gen_ndarray_shapes`, code duplication?
|
// TODO: Code duplication: This function looks exactly like `gen_ndarray_shapes`.
|
||||||
|
|
||||||
assert!(obj.is_none());
|
assert!(obj.is_none());
|
||||||
assert_eq!(args.len(), 1);
|
assert_eq!(args.len(), 1);
|
||||||
|
@ -415,14 +425,15 @@ pub fn gen_ndarray_strides<'ctx>(
|
||||||
// Parse argument #1 ndarray
|
// Parse argument #1 ndarray
|
||||||
let ndarray_ty = fun.0.args[0].ty;
|
let ndarray_ty = fun.0.args[0].ty;
|
||||||
let ndarray = args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
|
let ndarray = args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
|
||||||
|
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
|
||||||
|
|
||||||
// Define models
|
// Define models
|
||||||
let sizet_model = IntModel(SizeT);
|
let sizet_model = IntModel(SizeT);
|
||||||
|
|
||||||
// Process ndarray
|
// Process ndarray
|
||||||
let ndarray = NDArrayObject::from_value_and_type(generator, ctx, ndarray, ndarray_ty);
|
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
|
||||||
|
|
||||||
let mut items = Vec::with_capacity(ndarray.ndims as usize);
|
let mut objects = Vec::with_capacity(ndarray.ndims as usize);
|
||||||
|
|
||||||
for i in 0..ndarray.ndims {
|
for i in 0..ndarray.ndims {
|
||||||
let i = sizet_model.constant(generator, ctx.ctx, i);
|
let i = sizet_model.constant(generator, ctx.ctx, i);
|
||||||
|
@ -432,10 +443,11 @@ pub fn gen_ndarray_strides<'ctx>(
|
||||||
.ix(generator, ctx, i.value, "dim");
|
.ix(generator, ctx, i.value, "dim");
|
||||||
let dim = dim.truncate(generator, ctx, Int32, "dim"); // TODO: keep using SizeT
|
let dim = dim.truncate(generator, ctx, Int32, "dim"); // TODO: keep using SizeT
|
||||||
|
|
||||||
items.push((dim.value.as_basic_value_enum(), ctx.primitives.int32));
|
objects
|
||||||
|
.push(AnyObject { ty: ctx.primitives.int32, value: dim.value.as_basic_value_enum() });
|
||||||
}
|
}
|
||||||
|
|
||||||
let strides = TupleObject::create(generator, ctx, items, "strides");
|
let strides = TupleObject::create(generator, ctx, objects, "strides");
|
||||||
Ok(strides.value.as_basic_value_enum())
|
Ok(strides.value.as_basic_value_enum())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -458,20 +470,23 @@ pub fn gen_ndarray_transpose<'ctx>(
|
||||||
// Parse argument #1 ndarray
|
// Parse argument #1 ndarray
|
||||||
let ndarray_ty = fun.0.args[0].ty;
|
let ndarray_ty = fun.0.args[0].ty;
|
||||||
let ndarray = args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
|
let ndarray = args[0].1.clone().to_basic_value_enum(ctx, generator, ndarray_ty)?;
|
||||||
|
let ndarray = AnyObject { ty: ndarray_ty, value: ndarray };
|
||||||
|
|
||||||
// Implementation
|
// Implementation
|
||||||
let ndarray = NDArrayObject::from_value_and_type(generator, ctx, ndarray, ndarray_ty);
|
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
|
||||||
|
|
||||||
let has_axes = args.len() >= 2;
|
let has_axes = args.len() >= 2;
|
||||||
let transposed_ndarray = if has_axes {
|
let transposed_ndarray = if has_axes {
|
||||||
// Parse argument #2 axes
|
// Parse argument #2 axes
|
||||||
let in_axes_ty = fun.0.args[1].ty;
|
let in_axes_ty = fun.0.args[1].ty;
|
||||||
let in_axes = args[1].1.clone().to_basic_value_enum(ctx, generator, in_axes_ty)?;
|
let in_axes = args[1].1.clone().to_basic_value_enum(ctx, generator, in_axes_ty)?;
|
||||||
|
let in_axes = AnyObject { ty: in_axes_ty, value: in_axes };
|
||||||
|
|
||||||
let (_, axes) = parse_numpy_int_sequence(generator, ctx, in_axes, in_axes_ty);
|
let (_, axes) = parse_numpy_int_sequence(generator, ctx, in_axes);
|
||||||
|
|
||||||
ndarray.transpose(generator, ctx, Some(axes))
|
ndarray.transpose(generator, ctx, Some(axes))
|
||||||
} else {
|
} else {
|
||||||
|
// axes is not given
|
||||||
ndarray.transpose(generator, ctx, None)
|
ndarray.transpose(generator, ctx, None)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -488,8 +503,9 @@ pub fn gen_ndarray_array<'ctx>(
|
||||||
assert!(obj.is_none());
|
assert!(obj.is_none());
|
||||||
assert!(matches!(args.len(), 1..=3));
|
assert!(matches!(args.len(), 1..=3));
|
||||||
|
|
||||||
let obj_ty = fun.0.args[0].ty;
|
let object_ty = fun.0.args[0].ty;
|
||||||
let obj_arg = args[0].1.clone().to_basic_value_enum(ctx, generator, obj_ty)?;
|
let object = args[0].1.clone().to_basic_value_enum(ctx, generator, object_ty)?;
|
||||||
|
let object = AnyObject { ty: object_ty, value: object };
|
||||||
|
|
||||||
let copy_arg = if let Some(arg) =
|
let copy_arg = if let Some(arg) =
|
||||||
args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[1].name))
|
args.iter().find(|arg| arg.0.is_some_and(|name| name == fun.0.args[1].name))
|
||||||
|
@ -513,7 +529,7 @@ pub fn gen_ndarray_array<'ctx>(
|
||||||
let copy = IntModel(Byte).check_value(generator, ctx.ctx, copy_arg).unwrap(); // NAC3 booleans are i8
|
let copy = IntModel(Byte).check_value(generator, ctx.ctx, copy_arg).unwrap(); // NAC3 booleans are i8
|
||||||
let copy = copy.truncate(generator, ctx, Bool, "copy_bool");
|
let copy = copy.truncate(generator, ctx, Bool, "copy_bool");
|
||||||
|
|
||||||
let ndarray = NDArrayObject::from_np_array(generator, ctx, obj_arg, obj_ty, copy);
|
let ndarray = NDArrayObject::from_np_array(generator, ctx, object, copy);
|
||||||
debug_assert!(ndarray.ndims <= output_ndims); // Sanity check on `ndims`
|
debug_assert!(ndarray.ndims <= output_ndims); // Sanity check on `ndims`
|
||||||
|
|
||||||
let ndarray = ndarray.atleast_nd(generator, ctx, output_ndims);
|
let ndarray = ndarray.atleast_nd(generator, ctx, output_ndims);
|
||||||
|
|
|
@ -5,6 +5,8 @@ use crate::{
|
||||||
typecheck::typedef::{iter_type_vars, Type, TypeEnum},
|
typecheck::typedef::{iter_type_vars, Type, TypeEnum},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use super::AnyObject;
|
||||||
|
|
||||||
/// A NAC3 Python List object.
|
/// A NAC3 Python List object.
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
pub struct ListObject<'ctx> {
|
pub struct ListObject<'ctx> {
|
||||||
|
@ -15,21 +17,20 @@ pub struct ListObject<'ctx> {
|
||||||
|
|
||||||
impl<'ctx> ListObject<'ctx> {
|
impl<'ctx> ListObject<'ctx> {
|
||||||
/// Create a [`ListObject`] from an LLVM value and its typechecker [`Type`].
|
/// Create a [`ListObject`] from an LLVM value and its typechecker [`Type`].
|
||||||
pub fn from_value_and_type<V: BasicValue<'ctx>, G: CodeGenerator + ?Sized>(
|
pub fn from_object<G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
list_val: V,
|
object: AnyObject<'ctx>,
|
||||||
list_type: Type,
|
|
||||||
) -> Self {
|
) -> Self {
|
||||||
// Check typechecker type and extract `item_type`
|
// Check typechecker type and extract `item_type`
|
||||||
let item_type = match &*ctx.unifier.get_ty(list_type) {
|
let item_type = match &*ctx.unifier.get_ty(object.ty) {
|
||||||
TypeEnum::TObj { obj_id, params, .. }
|
TypeEnum::TObj { obj_id, params, .. }
|
||||||
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
|
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
|
||||||
{
|
{
|
||||||
iter_type_vars(params).next().unwrap().ty // Extract `item_type`
|
iter_type_vars(params).next().unwrap().ty // Extract `item_type`
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
panic!("Expecting type to be a list, but got {}", ctx.unifier.stringify(list_type))
|
panic!("Expecting type to be a list, but got {}", ctx.unifier.stringify(object.ty))
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -37,7 +38,7 @@ impl<'ctx> ListObject<'ctx> {
|
||||||
let plist_model = PtrModel(StructModel(List { item: item_model }));
|
let plist_model = PtrModel(StructModel(List { item: item_model }));
|
||||||
|
|
||||||
// Create object
|
// Create object
|
||||||
let value = plist_model.check_value(generator, ctx.ctx, list_val).unwrap();
|
let value = plist_model.check_value(generator, ctx.ctx, object.value).unwrap();
|
||||||
ListObject { item_type, instance: value }
|
ListObject { item_type, instance: value }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,13 @@
|
||||||
|
use inkwell::values::BasicValueEnum;
|
||||||
|
|
||||||
|
use crate::typecheck::typedef::Type;
|
||||||
|
|
||||||
pub mod list;
|
pub mod list;
|
||||||
pub mod ndarray;
|
pub mod ndarray;
|
||||||
pub mod tuple;
|
pub mod tuple;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct AnyObject<'ctx> {
|
||||||
|
pub ty: Type,
|
||||||
|
pub value: BasicValueEnum<'ctx>,
|
||||||
|
}
|
||||||
|
|
|
@ -1,11 +1,9 @@
|
||||||
use inkwell::values::BasicValueEnum;
|
|
||||||
|
|
||||||
use super::NDArrayObject;
|
use super::NDArrayObject;
|
||||||
use crate::{
|
use crate::{
|
||||||
codegen::{
|
codegen::{
|
||||||
irrt::{call_nac3_array_set_and_validate_list_shape, call_nac3_array_write_list_to_array},
|
irrt::{call_nac3_array_set_and_validate_list_shape, call_nac3_array_write_list_to_array},
|
||||||
model::*,
|
model::*,
|
||||||
object::list::ListObject,
|
object::{list::ListObject, AnyObject},
|
||||||
stmt::gen_if_else_expr_callback,
|
stmt::gen_if_else_expr_callback,
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
},
|
},
|
||||||
|
@ -26,7 +24,7 @@ fn get_list_object_dtype_and_ndims<'ctx>(
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ctx> NDArrayObject<'ctx> {
|
impl<'ctx> NDArrayObject<'ctx> {
|
||||||
fn from_np_array_list_copy<G: CodeGenerator + ?Sized>(
|
fn from_np_array_list_copy_impl<G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
list: ListObject<'ctx>,
|
list: ListObject<'ctx>,
|
||||||
|
@ -59,7 +57,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
ndarray
|
ndarray
|
||||||
}
|
}
|
||||||
|
|
||||||
fn from_np_array_list_try_no_copy<G: CodeGenerator + ?Sized>(
|
fn from_np_array_list_try_no_copy_impl<G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
list: ListObject<'ctx>,
|
list: ListObject<'ctx>,
|
||||||
|
@ -101,11 +99,11 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
ndarray
|
ndarray
|
||||||
} else {
|
} else {
|
||||||
// `list` is nested, it is impossible to not copy.
|
// `list` is nested, it is impossible to not copy.
|
||||||
NDArrayObject::from_np_array_list_copy(generator, ctx, list)
|
NDArrayObject::from_np_array_list_copy_impl(generator, ctx, list)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn from_np_array_list<G: CodeGenerator + ?Sized>(
|
fn from_np_array_list_impl<G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
list: ListObject<'ctx>,
|
list: ListObject<'ctx>,
|
||||||
|
@ -118,11 +116,12 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
ctx,
|
ctx,
|
||||||
|_generator, _ctx| Ok(copy.value),
|
|_generator, _ctx| Ok(copy.value),
|
||||||
|generator, ctx| {
|
|generator, ctx| {
|
||||||
let ndarray = NDArrayObject::from_np_array_list_copy(generator, ctx, list);
|
let ndarray = NDArrayObject::from_np_array_list_copy_impl(generator, ctx, list);
|
||||||
Ok(Some(ndarray.instance.value))
|
Ok(Some(ndarray.instance.value))
|
||||||
},
|
},
|
||||||
|generator, ctx| {
|
|generator, ctx| {
|
||||||
let ndarray = NDArrayObject::from_np_array_list_try_no_copy(generator, ctx, list);
|
let ndarray =
|
||||||
|
NDArrayObject::from_np_array_list_try_no_copy_impl(generator, ctx, list);
|
||||||
Ok(Some(ndarray.instance.value))
|
Ok(Some(ndarray.instance.value))
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
@ -132,7 +131,7 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
NDArrayObject::from_value_and_unpacked_types(generator, ctx, ndarray, dtype, ndims)
|
NDArrayObject::from_value_and_unpacked_types(generator, ctx, ndarray, dtype, ndims)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn from_np_array_ndarray<G: CodeGenerator + ?Sized>(
|
pub fn from_np_array_ndarray_impl<G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
ndarray: NDArrayObject<'ctx>,
|
ndarray: NDArrayObject<'ctx>,
|
||||||
|
@ -166,24 +165,23 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
pub fn from_np_array<G: CodeGenerator + ?Sized>(
|
pub fn from_np_array<G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
object: BasicValueEnum<'ctx>,
|
object: AnyObject<'ctx>,
|
||||||
object_ty: Type,
|
|
||||||
copy: Int<'ctx, Bool>,
|
copy: Int<'ctx, Bool>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
match &*ctx.unifier.get_ty(object_ty) {
|
match &*ctx.unifier.get_ty(object.ty) {
|
||||||
TypeEnum::TObj { obj_id, .. }
|
TypeEnum::TObj { obj_id, .. }
|
||||||
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
|
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
|
||||||
{
|
{
|
||||||
let list = ListObject::from_value_and_type(generator, ctx, object, object_ty);
|
let list = ListObject::from_object(generator, ctx, object);
|
||||||
NDArrayObject::from_np_array_list(generator, ctx, list, copy)
|
NDArrayObject::from_np_array_list_impl(generator, ctx, list, copy)
|
||||||
}
|
}
|
||||||
TypeEnum::TObj { obj_id, .. }
|
TypeEnum::TObj { obj_id, .. }
|
||||||
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
|
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
|
||||||
{
|
{
|
||||||
let ndarray = NDArrayObject::from_value_and_type(generator, ctx, object, object_ty);
|
let ndarray = NDArrayObject::from_object(generator, ctx, object);
|
||||||
NDArrayObject::from_np_array_ndarray(generator, ctx, ndarray, copy)
|
NDArrayObject::from_np_array_ndarray_impl(generator, ctx, ndarray, copy)
|
||||||
}
|
}
|
||||||
_ => panic!("Unrecognized object type: {}", ctx.unifier.stringify(object_ty)), // Typechecker ensures this
|
_ => panic!("Unrecognized object type: {}", ctx.unifier.stringify(object.ty)), // Typechecker ensures this
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,6 +35,8 @@ use inkwell::{
|
||||||
use scalar::{ScalarObject, ScalarOrNDArray};
|
use scalar::{ScalarObject, ScalarOrNDArray};
|
||||||
use util::{call_memcpy_model, gen_for_model_auto};
|
use util::{call_memcpy_model, gen_for_model_auto};
|
||||||
|
|
||||||
|
use super::AnyObject;
|
||||||
|
|
||||||
/// A NAC3 Python ndarray object.
|
/// A NAC3 Python ndarray object.
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
pub struct NDArrayObject<'ctx> {
|
pub struct NDArrayObject<'ctx> {
|
||||||
|
@ -45,18 +47,17 @@ pub struct NDArrayObject<'ctx> {
|
||||||
|
|
||||||
impl<'ctx> NDArrayObject<'ctx> {
|
impl<'ctx> NDArrayObject<'ctx> {
|
||||||
/// Create an [`NDArrayObject`] from an LLVM value and its typechecker [`Type`].
|
/// Create an [`NDArrayObject`] from an LLVM value and its typechecker [`Type`].
|
||||||
pub fn from_value_and_type<V: BasicValue<'ctx>, G: CodeGenerator + ?Sized>(
|
pub fn from_object<G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
value: V,
|
object: AnyObject<'ctx>,
|
||||||
ty: Type,
|
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
let (dtype, ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, object.ty);
|
||||||
let ndims = extract_ndims(&ctx.unifier, ndims);
|
let ndims = extract_ndims(&ctx.unifier, ndims);
|
||||||
Self::from_value_and_unpacked_types(generator, ctx, value, dtype, ndims)
|
Self::from_value_and_unpacked_types(generator, ctx, object.value, dtype, ndims)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Like [`NDArrayObject::from_value_and_type`] but you directly supply the ndarray's
|
/// Like [`NDArrayObject::from_object`] but you directly supply the ndarray's
|
||||||
/// `dtype` and `ndims`.
|
/// `dtype` and `ndims`.
|
||||||
pub fn from_value_and_unpacked_types<V: BasicValue<'ctx>, G: CodeGenerator + ?Sized>(
|
pub fn from_value_and_unpacked_types<V: BasicValue<'ctx>, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
use inkwell::values::{BasicValue, BasicValueEnum};
|
use inkwell::values::{BasicValue, BasicValueEnum};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
codegen::{model::*, CodeGenContext, CodeGenerator},
|
codegen::{model::*, object::AnyObject, CodeGenContext, CodeGenerator},
|
||||||
typecheck::typedef::{Type, TypeEnum},
|
typecheck::typedef::{Type, TypeEnum},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -120,23 +120,22 @@ impl<'ctx> TryFrom<&ScalarOrNDArray<'ctx>> for NDArrayObject<'ctx> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Split an [`BasicValueEnum<'ctx>`] into a [`ScalarOrNDArray`] depending
|
/// Split an [`AnyObject`] into a [`ScalarOrNDArray`] depending
|
||||||
/// on its [`Type`].
|
/// on its [`Type`].
|
||||||
pub fn split_scalar_or_ndarray<'ctx, G: CodeGenerator + ?Sized>(
|
pub fn split_scalar_or_ndarray<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
input: BasicValueEnum<'ctx>,
|
object: AnyObject<'ctx>,
|
||||||
input_ty: Type,
|
|
||||||
) -> ScalarOrNDArray<'ctx> {
|
) -> ScalarOrNDArray<'ctx> {
|
||||||
match &*ctx.unifier.get_ty(input_ty) {
|
match &*ctx.unifier.get_ty(object.ty) {
|
||||||
TypeEnum::TObj { obj_id, .. }
|
TypeEnum::TObj { obj_id, .. }
|
||||||
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
|
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
|
||||||
{
|
{
|
||||||
let ndarray = NDArrayObject::from_value_and_type(generator, ctx, input, input_ty);
|
let ndarray = NDArrayObject::from_object(generator, ctx, object);
|
||||||
ScalarOrNDArray::NDArray(ndarray)
|
ScalarOrNDArray::NDArray(ndarray)
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
let scalar = ScalarObject { dtype: input_ty, instance: input };
|
let scalar = ScalarObject { dtype: object.ty, instance: object.value };
|
||||||
ScalarOrNDArray::Scalar(scalar)
|
ScalarOrNDArray::Scalar(scalar)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,9 +1,12 @@
|
||||||
use inkwell::values::BasicValueEnum;
|
|
||||||
use util::gen_for_model_auto;
|
use util::gen_for_model_auto;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
codegen::{model::*, object::list::ListObject, CodeGenContext, CodeGenerator},
|
codegen::{
|
||||||
typecheck::typedef::{Type, TypeEnum},
|
model::*,
|
||||||
|
object::{list::ListObject, tuple::TupleObject, AnyObject},
|
||||||
|
CodeGenContext, CodeGenerator,
|
||||||
|
},
|
||||||
|
typecheck::typedef::TypeEnum,
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Parse a NumPy-like "int sequence" input and return the int sequence as an array and its length.
|
/// Parse a NumPy-like "int sequence" input and return the int sequence as an array and its length.
|
||||||
|
@ -20,8 +23,7 @@ use crate::{
|
||||||
pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
|
pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
input_sequence: BasicValueEnum<'ctx>,
|
input_sequence: AnyObject<'ctx>,
|
||||||
input_sequence_ty: Type,
|
|
||||||
) -> (Int<'ctx, SizeT>, Ptr<'ctx, IntModel<SizeT>>) {
|
) -> (Int<'ctx, SizeT>, Ptr<'ctx, IntModel<SizeT>>) {
|
||||||
let sizet_model = IntModel(SizeT);
|
let sizet_model = IntModel(SizeT);
|
||||||
|
|
||||||
|
@ -29,15 +31,14 @@ pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
let one = sizet_model.const_1(generator, ctx.ctx);
|
let one = sizet_model.const_1(generator, ctx.ctx);
|
||||||
|
|
||||||
// The result `list` to return.
|
// The result `list` to return.
|
||||||
match &*ctx.unifier.get_ty(input_sequence_ty) {
|
match &*ctx.unifier.get_ty(input_sequence.ty) {
|
||||||
TypeEnum::TObj { obj_id, .. }
|
TypeEnum::TObj { obj_id, .. }
|
||||||
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
|
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
|
||||||
{
|
{
|
||||||
// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
|
// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
|
||||||
|
|
||||||
// Check `input_sequence`
|
// Check `input_sequence`
|
||||||
let input_sequence =
|
let input_sequence = ListObject::from_object(generator, ctx, input_sequence);
|
||||||
ListObject::from_value_and_type(generator, ctx, input_sequence, input_sequence_ty);
|
|
||||||
|
|
||||||
let len = input_sequence.instance.gep(ctx, |f| f.len).load(generator, ctx, "len");
|
let len = input_sequence.instance.gep(ctx, |f| f.len).load(generator, ctx, "len");
|
||||||
let result = sizet_model.array_alloca(generator, ctx, len.value, "int_sequence");
|
let result = sizet_model.array_alloca(generator, ctx, len.value, "int_sequence");
|
||||||
|
@ -66,20 +67,17 @@ pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
}
|
}
|
||||||
TypeEnum::TTuple { ty: tuple_types, .. } => {
|
TypeEnum::TTuple { ty: tuple_types, .. } => {
|
||||||
// 2. A tuple of ints; e.g., `np.empty((600, 800, 3))`
|
// 2. A tuple of ints; e.g., `np.empty((600, 800, 3))`
|
||||||
let input_sequence = input_sequence.into_struct_value(); // A tuple is a struct
|
|
||||||
|
|
||||||
let len_int = tuple_types.len();
|
let input_sequence = TupleObject::from_object(ctx, input_sequence);
|
||||||
|
|
||||||
|
let len_int = input_sequence.len();
|
||||||
let len = sizet_model.constant(generator, ctx.ctx, len_int as u64);
|
let len = sizet_model.constant(generator, ctx.ctx, len_int as u64);
|
||||||
|
|
||||||
let result = sizet_model.array_alloca(generator, ctx, len.value, "int_sequence");
|
let result = sizet_model.array_alloca(generator, ctx, len.value, "int_sequence");
|
||||||
|
|
||||||
for i in 0..len_int {
|
for i in 0..len_int {
|
||||||
// Get the i-th element off of the tuple and load it into `result`.
|
// Get the i-th element off of the tuple and load it into `result`.
|
||||||
let int = ctx
|
let int = input_sequence.get(ctx, i, "dim").value.into_int_value();
|
||||||
.builder
|
|
||||||
.build_extract_value(input_sequence, i as u32, "int")
|
|
||||||
.unwrap()
|
|
||||||
.into_int_value();
|
|
||||||
let int = sizet_model.s_extend_or_bit_cast(generator, ctx, int, "int");
|
let int = sizet_model.s_extend_or_bit_cast(generator, ctx, int, "int");
|
||||||
|
|
||||||
let offset = sizet_model.constant(generator, ctx.ctx, i as u64);
|
let offset = sizet_model.constant(generator, ctx.ctx, i as u64);
|
||||||
|
@ -92,7 +90,7 @@ pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() =>
|
if *obj_id == ctx.primitives.int32.obj_id(&ctx.unifier).unwrap() =>
|
||||||
{
|
{
|
||||||
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
|
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
|
||||||
let input_int = input_sequence.into_int_value();
|
let input_int = input_sequence.value.into_int_value();
|
||||||
|
|
||||||
let len = sizet_model.const_1(generator, ctx.ctx);
|
let len = sizet_model.const_1(generator, ctx.ctx);
|
||||||
let result = sizet_model.array_alloca(generator, ctx, len.value, "int_sequence");
|
let result = sizet_model.array_alloca(generator, ctx, len.value, "int_sequence");
|
||||||
|
@ -106,7 +104,7 @@ pub fn parse_numpy_int_sequence<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
}
|
}
|
||||||
_ => panic!(
|
_ => panic!(
|
||||||
"encountered unknown sequence type: {}",
|
"encountered unknown sequence type: {}",
|
||||||
ctx.unifier.stringify(input_sequence_ty)
|
ctx.unifier.stringify(input_sequence.ty)
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,34 +1,68 @@
|
||||||
use inkwell::values::{BasicValueEnum, StructValue};
|
use core::panic;
|
||||||
|
|
||||||
|
use inkwell::values::StructValue;
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
codegen::{CodeGenContext, CodeGenerator},
|
codegen::{CodeGenContext, CodeGenerator},
|
||||||
typecheck::typedef::Type,
|
typecheck::typedef::{Type, TypeEnum},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use super::AnyObject;
|
||||||
|
|
||||||
|
/// A NAC3 tuple object.
|
||||||
|
///
|
||||||
|
/// NOTE: This struct has no copy trait.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct TupleObject<'ctx> {
|
pub struct TupleObject<'ctx> {
|
||||||
|
/// The type of the tuple.
|
||||||
pub tys: Vec<Type>,
|
pub tys: Vec<Type>,
|
||||||
|
/// The underlying LLVM value of this tuple.
|
||||||
pub value: StructValue<'ctx>,
|
pub value: StructValue<'ctx>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ctx> TupleObject<'ctx> {
|
impl<'ctx> TupleObject<'ctx> {
|
||||||
|
// NOTE: There is no Model abstraction for Tuples. Everything has to be done raw with Inkwell.
|
||||||
|
|
||||||
|
pub fn from_object(ctx: &mut CodeGenContext<'ctx, '_>, object: AnyObject<'ctx>) -> Self {
|
||||||
|
// TODO: Keep `is_vararg_ctx` from TTuple?
|
||||||
|
let TypeEnum::TTuple { ty: tys, .. } = &*ctx.unifier.get_ty(object.ty) else {
|
||||||
|
panic!(
|
||||||
|
"Expected type to be a TypeEnum::TTuple, got {}",
|
||||||
|
ctx.unifier.stringify(object.ty)
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
let value = object.value.into_struct_value();
|
||||||
|
if value.get_type().count_fields() as usize != tys.len() {
|
||||||
|
panic!(
|
||||||
|
"Tuple type has {} item(s), but the LLVM struct value has {} field(s)",
|
||||||
|
tys.len(),
|
||||||
|
value.get_type().count_fields()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
TupleObject { tys: tys.clone(), value }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convenience function. Create a [`TupleObject`] from an iterator of objects.
|
||||||
pub fn create<I, G: CodeGenerator + ?Sized>(
|
pub fn create<I, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
items: I,
|
objects: I,
|
||||||
name: &str,
|
name: &str,
|
||||||
) -> Self
|
) -> Self
|
||||||
where
|
where
|
||||||
I: IntoIterator<Item = (BasicValueEnum<'ctx>, Type)>,
|
I: IntoIterator<Item = AnyObject<'ctx>>,
|
||||||
{
|
{
|
||||||
let (vals, tys): (Vec<_>, Vec<_>) = items.into_iter().unzip();
|
let (values, tys): (Vec<_>, Vec<_>) =
|
||||||
|
objects.into_iter().map(|object| (object.value, object.ty)).unzip();
|
||||||
|
|
||||||
// let tuple_ty = ctx.unifier.add_ty(TypeEnum::TTuple { ty: tys });
|
|
||||||
let llvm_tys = tys.iter().map(|ty| ctx.get_llvm_type(generator, *ty)).collect_vec();
|
let llvm_tys = tys.iter().map(|ty| ctx.get_llvm_type(generator, *ty)).collect_vec();
|
||||||
let llvm_tuple_ty = ctx.ctx.struct_type(&llvm_tys, false);
|
let llvm_tuple_ty = ctx.ctx.struct_type(&llvm_tys, false);
|
||||||
|
|
||||||
let pllvm_tuple = ctx.builder.build_alloca(llvm_tuple_ty, "tuple").unwrap();
|
let pllvm_tuple = ctx.builder.build_alloca(llvm_tuple_ty, "tuple").unwrap();
|
||||||
for (i, val) in vals.into_iter().enumerate() {
|
for (i, val) in values.into_iter().enumerate() {
|
||||||
// Store the dim value into the tuple
|
|
||||||
let pval = ctx.builder.build_struct_gep(pllvm_tuple, i as u32, "value").unwrap();
|
let pval = ctx.builder.build_struct_gep(pllvm_tuple, i as u32, "value").unwrap();
|
||||||
ctx.builder.build_store(pval, val).unwrap();
|
ctx.builder.build_store(pval, val).unwrap();
|
||||||
}
|
}
|
||||||
|
@ -36,4 +70,22 @@ impl<'ctx> TupleObject<'ctx> {
|
||||||
let value = ctx.builder.build_load(pllvm_tuple, name).unwrap().into_struct_value();
|
let value = ctx.builder.build_load(pllvm_tuple, name).unwrap().into_struct_value();
|
||||||
TupleObject { tys, value }
|
TupleObject { tys, value }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get the `len()` of this tuple.
|
||||||
|
///
|
||||||
|
/// We statically know the lengths of tuples in NAC3.
|
||||||
|
pub fn len(&self) -> usize {
|
||||||
|
self.tys.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the `i`-th (0-based) object in this tuple.
|
||||||
|
pub fn get(&self, ctx: &mut CodeGenContext<'ctx, '_>, i: usize, name: &str) -> AnyObject<'ctx> {
|
||||||
|
if i >= self.len() {
|
||||||
|
panic!("Tuple object with length {} have index {i}", self.len());
|
||||||
|
}
|
||||||
|
|
||||||
|
let value = ctx.builder.build_extract_value(self.value, i as u32, name).unwrap();
|
||||||
|
let ty = self.tys[i];
|
||||||
|
AnyObject { value, ty }
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ use super::model::*;
|
||||||
use super::object::ndarray::indexing::util::gen_ndarray_subscript_ndindexes;
|
use super::object::ndarray::indexing::util::gen_ndarray_subscript_ndindexes;
|
||||||
use super::object::ndarray::scalar::split_scalar_or_ndarray;
|
use super::object::ndarray::scalar::split_scalar_or_ndarray;
|
||||||
use super::object::ndarray::NDArrayObject;
|
use super::object::ndarray::NDArrayObject;
|
||||||
|
use super::object::AnyObject;
|
||||||
use super::{
|
use super::{
|
||||||
super::symbol_resolver::ValueEnum,
|
super::symbol_resolver::ValueEnum,
|
||||||
expr::destructure_range,
|
expr::destructure_range,
|
||||||
|
@ -411,13 +412,14 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
|
||||||
.gen_expr(ctx, target)?
|
.gen_expr(ctx, target)?
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.to_basic_value_enum(ctx, generator, target_ty)?;
|
.to_basic_value_enum(ctx, generator, target_ty)?;
|
||||||
let target = NDArrayObject::from_value_and_type(generator, ctx, target, target_ty);
|
let target = AnyObject { value: target, ty: target_ty };
|
||||||
|
|
||||||
// Process key
|
// Process key
|
||||||
let key = gen_ndarray_subscript_ndindexes(generator, ctx, key)?;
|
let key = gen_ndarray_subscript_ndindexes(generator, ctx, key)?;
|
||||||
|
|
||||||
// Process value
|
// Process value
|
||||||
let value = value.to_basic_value_enum(ctx, generator, value_ty)?;
|
let value = value.to_basic_value_enum(ctx, generator, value_ty)?;
|
||||||
|
let value = AnyObject { value, ty: value_ty };
|
||||||
|
|
||||||
/*
|
/*
|
||||||
Reference code:
|
Reference code:
|
||||||
|
@ -433,9 +435,10 @@ pub fn gen_setitem<'ctx, G: CodeGenerator>(
|
||||||
...and finally copy 1-1 from value to target.
|
...and finally copy 1-1 from value to target.
|
||||||
```
|
```
|
||||||
*/
|
*/
|
||||||
|
let target = NDArrayObject::from_object(generator, ctx, target);
|
||||||
let target = target.index(generator, ctx, &key, "assign_target_ndarray");
|
let target = target.index(generator, ctx, &key, "assign_target_ndarray");
|
||||||
let value =
|
|
||||||
split_scalar_or_ndarray(generator, ctx, value, value_ty).as_ndarray(generator, ctx);
|
let value = split_scalar_or_ndarray(generator, ctx, value).as_ndarray(generator, ctx);
|
||||||
|
|
||||||
let broadcast_result = NDArrayObject::broadcast(generator, ctx, &[target, value]);
|
let broadcast_result = NDArrayObject::broadcast(generator, ctx, &[target, value]);
|
||||||
|
|
||||||
|
|
|
@ -18,9 +18,12 @@ use crate::{
|
||||||
extern_fns, irrt, llvm_intrinsics,
|
extern_fns, irrt, llvm_intrinsics,
|
||||||
numpy::*,
|
numpy::*,
|
||||||
numpy_new::{self, gen_ndarray_transpose},
|
numpy_new::{self, gen_ndarray_transpose},
|
||||||
object::ndarray::{
|
object::{
|
||||||
functions::{FloorOrCeil, MinOrMax},
|
ndarray::{
|
||||||
scalar::{split_scalar_or_ndarray, ScalarObject, ScalarOrNDArray},
|
functions::{FloorOrCeil, MinOrMax},
|
||||||
|
scalar::{split_scalar_or_ndarray, ScalarObject, ScalarOrNDArray},
|
||||||
|
},
|
||||||
|
AnyObject,
|
||||||
},
|
},
|
||||||
stmt::exn_constructor,
|
stmt::exn_constructor,
|
||||||
},
|
},
|
||||||
|
@ -1080,6 +1083,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
move |ctx, _, fun, args, generator| {
|
move |ctx, _, fun, args, generator| {
|
||||||
let arg_ty = fun.0.args[0].ty;
|
let arg_ty = fun.0.args[0].ty;
|
||||||
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||||
|
let arg = AnyObject { ty: arg_ty, value: arg };
|
||||||
|
|
||||||
let ret_dtype = match prim {
|
let ret_dtype = match prim {
|
||||||
PrimDef::FunInt32 => ctx.primitives.int32,
|
PrimDef::FunInt32 => ctx.primitives.int32,
|
||||||
|
@ -1091,7 +1095,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let result = split_scalar_or_ndarray(generator, ctx, arg, arg_ty).map(
|
let result = split_scalar_or_ndarray(generator, ctx, arg).map(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ret_dtype,
|
ret_dtype,
|
||||||
|
@ -1158,10 +1162,11 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
Box::new(move |ctx, _, fun, args, generator| {
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
let arg_ty = fun.0.args[0].ty;
|
let arg_ty = fun.0.args[0].ty;
|
||||||
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||||
|
let arg = AnyObject { ty: arg_ty, value: arg };
|
||||||
|
|
||||||
let ret_int_dtype = size_variant.of_int(&ctx.primitives);
|
let ret_int_dtype = size_variant.of_int(&ctx.primitives);
|
||||||
|
|
||||||
let result = split_scalar_or_ndarray(generator, ctx, arg, arg_ty).map(
|
let result = split_scalar_or_ndarray(generator, ctx, arg).map(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ret_int_dtype,
|
ret_int_dtype,
|
||||||
|
@ -1225,8 +1230,9 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
Box::new(move |ctx, _, fun, args, generator| {
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
let arg_ty = fun.0.args[0].ty;
|
let arg_ty = fun.0.args[0].ty;
|
||||||
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||||
|
let arg = AnyObject { ty: arg_ty, value: arg };
|
||||||
|
|
||||||
let result = split_scalar_or_ndarray(generator, ctx, arg, arg_ty).map(
|
let result = split_scalar_or_ndarray(generator, ctx, arg).map(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
int_sized,
|
int_sized,
|
||||||
|
@ -1618,6 +1624,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
Box::new(move |ctx, _, fun, args, generator| {
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
let arg_ty = fun.0.args[0].ty;
|
let arg_ty = fun.0.args[0].ty;
|
||||||
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||||
|
let arg = AnyObject { ty: arg_ty, value: arg };
|
||||||
|
|
||||||
let kind = match prim {
|
let kind = match prim {
|
||||||
PrimDef::FunNpFloor => FloorOrCeil::Floor,
|
PrimDef::FunNpFloor => FloorOrCeil::Floor,
|
||||||
|
@ -1625,7 +1632,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let result = split_scalar_or_ndarray(generator, ctx, arg, arg_ty).map(
|
let result = split_scalar_or_ndarray(generator, ctx, arg).map(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.float,
|
ctx.primitives.float,
|
||||||
|
@ -1652,8 +1659,9 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
Box::new(|ctx, _, fun, args, generator| {
|
Box::new(|ctx, _, fun, args, generator| {
|
||||||
let arg_ty = fun.0.args[0].ty;
|
let arg_ty = fun.0.args[0].ty;
|
||||||
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||||
|
let arg = AnyObject { ty: arg_ty, value: arg };
|
||||||
|
|
||||||
let result = split_scalar_or_ndarray(generator, ctx, arg, arg_ty).map(
|
let result = split_scalar_or_ndarray(generator, ctx, arg).map(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.float,
|
ctx.primitives.float,
|
||||||
|
@ -1790,8 +1798,9 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
Box::new(move |ctx, _, fun, args, generator| {
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
let a_ty = fun.0.args[0].ty;
|
let a_ty = fun.0.args[0].ty;
|
||||||
let a = args[0].1.clone().to_basic_value_enum(ctx, generator, a_ty)?;
|
let a = args[0].1.clone().to_basic_value_enum(ctx, generator, a_ty)?;
|
||||||
|
let a = AnyObject { ty: a_ty, value: a };
|
||||||
|
|
||||||
let a = split_scalar_or_ndarray(generator, ctx, a, a_ty).as_ndarray(generator, ctx);
|
let a = split_scalar_or_ndarray(generator, ctx, a).as_ndarray(generator, ctx);
|
||||||
let result = match prim {
|
let result = match prim {
|
||||||
PrimDef::FunNpArgmin => a
|
PrimDef::FunNpArgmin => a
|
||||||
.argmin_or_argmax(generator, ctx, MinOrMax::Min)
|
.argmin_or_argmax(generator, ctx, MinOrMax::Min)
|
||||||
|
@ -1845,9 +1854,12 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||||
move |ctx, _, fun, args, generator| {
|
move |ctx, _, fun, args, generator| {
|
||||||
let x1_ty = fun.0.args[0].ty;
|
let x1_ty = fun.0.args[0].ty;
|
||||||
let x2_ty = fun.0.args[1].ty;
|
|
||||||
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||||
|
let x1 = AnyObject { ty: x1_ty, value: x1_val };
|
||||||
|
|
||||||
|
let x2_ty = fun.0.args[1].ty;
|
||||||
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
||||||
|
let x2 = AnyObject { ty: x2_ty, value: x2_val };
|
||||||
|
|
||||||
let kind = match prim {
|
let kind = match prim {
|
||||||
PrimDef::FunNpMinimum => MinOrMax::Min,
|
PrimDef::FunNpMinimum => MinOrMax::Min,
|
||||||
|
@ -1855,8 +1867,8 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let x1 = split_scalar_or_ndarray(generator, ctx, x1_val, x1_ty);
|
let x1 = split_scalar_or_ndarray(generator, ctx, x1);
|
||||||
let x2 = split_scalar_or_ndarray(generator, ctx, x2_val, x2_ty);
|
let x2 = split_scalar_or_ndarray(generator, ctx, x2);
|
||||||
|
|
||||||
// NOTE: x1.dtype() and x2.dtype() should be the same
|
// NOTE: x1.dtype() and x2.dtype() should be the same
|
||||||
let common_ty = x1.dtype();
|
let common_ty = x1.dtype();
|
||||||
|
@ -1907,8 +1919,9 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
move |ctx, _, fun, args, generator| {
|
move |ctx, _, fun, args, generator| {
|
||||||
let n_ty = fun.0.args[0].ty;
|
let n_ty = fun.0.args[0].ty;
|
||||||
let n_val = args[0].1.clone().to_basic_value_enum(ctx, generator, n_ty)?;
|
let n_val = args[0].1.clone().to_basic_value_enum(ctx, generator, n_ty)?;
|
||||||
|
let n = AnyObject { ty: n_ty, value: n_val };
|
||||||
|
|
||||||
let result = split_scalar_or_ndarray(generator, ctx, n_val, n_ty).map(
|
let result = split_scalar_or_ndarray(generator, ctx, n).map(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
num_ty.ty,
|
num_ty.ty,
|
||||||
|
@ -1936,6 +1949,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
Box::new(move |ctx, _, fun, args, generator| {
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
let x_ty = fun.0.args[0].ty;
|
let x_ty = fun.0.args[0].ty;
|
||||||
let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?;
|
let x_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x_ty)?;
|
||||||
|
let x = AnyObject { value: x_val, ty: x_ty };
|
||||||
|
|
||||||
let function = match prim {
|
let function = match prim {
|
||||||
PrimDef::FunNpIsInf => irrt::call_isnan,
|
PrimDef::FunNpIsInf => irrt::call_isnan,
|
||||||
|
@ -1943,7 +1957,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let result = split_scalar_or_ndarray(generator, ctx, x_val, x_ty).map(
|
let result = split_scalar_or_ndarray(generator, ctx, x).map(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.bool,
|
ctx.primitives.bool,
|
||||||
|
@ -2010,8 +2024,9 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
Box::new(move |ctx, _, fun, args, generator| {
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
let arg_ty = fun.0.args[0].ty;
|
let arg_ty = fun.0.args[0].ty;
|
||||||
let arg_val = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
let arg_val = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||||
|
let arg = AnyObject { ty: arg_ty, value: arg_val };
|
||||||
|
|
||||||
let result = split_scalar_or_ndarray(generator, ctx, arg_val, arg_ty).map(
|
let result = split_scalar_or_ndarray(generator, ctx, arg).map(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.float,
|
ctx.primitives.float,
|
||||||
|
@ -2126,12 +2141,14 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
move |ctx, _, fun, args, generator| {
|
move |ctx, _, fun, args, generator| {
|
||||||
let x1_ty = fun.0.args[0].ty;
|
let x1_ty = fun.0.args[0].ty;
|
||||||
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||||
|
let x1 = AnyObject { ty: x1_ty, value: x1_val };
|
||||||
|
|
||||||
let x2_ty = fun.0.args[1].ty;
|
let x2_ty = fun.0.args[1].ty;
|
||||||
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
||||||
|
let x2 = AnyObject { ty: x2_ty, value: x2_val };
|
||||||
|
|
||||||
let x1 = split_scalar_or_ndarray(generator, ctx, x1_val, x1_ty);
|
let x1 = split_scalar_or_ndarray(generator, ctx, x1);
|
||||||
let x2 = split_scalar_or_ndarray(generator, ctx, x2_val, x2_ty);
|
let x2 = split_scalar_or_ndarray(generator, ctx, x2);
|
||||||
|
|
||||||
let result = ScalarOrNDArray::broadcasting_starmap(
|
let result = ScalarOrNDArray::broadcasting_starmap(
|
||||||
generator,
|
generator,
|
||||||
|
|
Loading…
Reference in New Issue