1
0
forked from M-Labs/nac3

core/ndstrides: implement len(ndarray) & refactor len()

This commit is contained in:
lyken 2024-08-20 12:29:31 +08:00 committed by David Mak
parent 792374fa9a
commit 54a842a93f

View File

@ -7,16 +7,17 @@ use itertools::Itertools;
use super::{ use super::{
classes::{ classes::{
ArrayLikeValue, NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor, NDArrayValue, ProxyValue, RangeValue, UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
}, },
expr::destructure_range, expr::destructure_range,
extern_fns, irrt, extern_fns, irrt,
irrt::calculate_len_for_slice_range, irrt::calculate_len_for_slice_range,
llvm_intrinsics, llvm_intrinsics,
macros::codegen_unreachable, macros::codegen_unreachable,
model::*,
numpy, numpy,
numpy::ndarray_elementwise_unaryop_impl, numpy::ndarray_elementwise_unaryop_impl,
object::{any::AnyObject, list::ListObject, ndarray::NDArrayObject, tuple::TupleObject},
stmt::gen_for_callback_incrementing, stmt::gen_for_callback_incrementing,
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}; };
@ -42,58 +43,33 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>(
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
n: (Type, BasicValueEnum<'ctx>), n: (Type, BasicValueEnum<'ctx>),
) -> Result<IntValue<'ctx>, String> { ) -> Result<IntValue<'ctx>, String> {
let llvm_i32 = ctx.ctx.i32_type();
let range_ty = ctx.primitives.range;
let (arg_ty, arg) = n; let (arg_ty, arg) = n;
Ok(if ctx.unifier.unioned(arg_ty, ctx.primitives.range) {
Ok(if ctx.unifier.unioned(arg_ty, range_ty) {
let arg = RangeValue::from_ptr_val(arg.into_pointer_value(), Some("range")); let arg = RangeValue::from_ptr_val(arg.into_pointer_value(), Some("range"));
let (start, end, step) = destructure_range(ctx, arg); let (start, end, step) = destructure_range(ctx, arg);
calculate_len_for_slice_range(generator, ctx, start, end, step) calculate_len_for_slice_range(generator, ctx, start, end, step)
} else { } else {
match &*ctx.unifier.get_ty_immutable(arg_ty) { let arg = AnyObject { ty: arg_ty, value: arg };
TypeEnum::TTuple { ty, .. } => llvm_i32.const_int(ty.len() as u64, false), let len: Instance<'ctx, Int<Int32>> = match &*ctx.unifier.get_ty(arg_ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::List.id() => { TypeEnum::TTuple { .. } => {
let zero = llvm_i32.const_zero(); let tuple = TupleObject::from_object(ctx, arg);
let len = ctx tuple.len(generator, ctx).truncate_or_bit_cast(generator, ctx, Int32)
.build_gep_and_load(
arg.into_pointer_value(),
&[zero, llvm_i32.const_int(1, false)],
None,
)
.into_int_value();
ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap()
} }
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, .. }
let llvm_usize = generator.get_size_type(ctx.ctx); if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
{
let arg = NDArrayValue::from_ptr_val(arg.into_pointer_value(), llvm_usize, None); let ndarray = NDArrayObject::from_object(generator, ctx, arg);
ndarray.len(generator, ctx).truncate_or_bit_cast(generator, ctx, Int32)
let ndims = arg.dim_sizes().size(ctx, generator);
ctx.make_assert(
generator,
ctx.builder
.build_int_compare(IntPredicate::NE, ndims, llvm_usize.const_zero(), "")
.unwrap(),
"0:TypeError",
"len() of unsized object",
[None, None, None],
ctx.current_loc,
);
let len = unsafe {
arg.dim_sizes().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_zero(),
None,
)
};
ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap()
} }
_ => codegen_unreachable!(ctx), TypeEnum::TObj { obj_id, .. }
} if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
{
let list = ListObject::from_object(generator, ctx, arg);
list.len(generator, ctx).truncate_or_bit_cast(generator, ctx, Int32)
}
_ => unsupported_type(ctx, "len", &[arg_ty]),
};
len.value
}) })
} }