forked from M-Labs/nac3
[core] codegen: Refactor len()
Based on 54a842a9
: core/ndstrides: implement len(ndarray) & refactor
len()
This commit is contained in:
parent
5880f964bb
commit
26f1428739
@ -14,9 +14,9 @@ use super::{
|
||||
numpy,
|
||||
numpy::ndarray_elementwise_unaryop_impl,
|
||||
stmt::gen_for_callback_incrementing,
|
||||
types::{ndarray::NDArrayType, TupleType},
|
||||
types::{ndarray::NDArrayType, ListType, TupleType},
|
||||
values::{
|
||||
ndarray::NDArrayValue, ArrayLikeValue, ProxyValue, RangeValue, TypedArrayLikeAccessor,
|
||||
ndarray::NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor,
|
||||
UntypedArrayLikeAccessor,
|
||||
},
|
||||
CodeGenContext, CodeGenerator,
|
||||
@ -55,42 +55,33 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>(
|
||||
calculate_len_for_slice_range(generator, ctx, start, end, step)
|
||||
} else {
|
||||
match &*ctx.unifier.get_ty_immutable(arg_ty) {
|
||||
TypeEnum::TTuple { ty, .. } => llvm_i32.const_int(ty.len() as u64, false),
|
||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::List.id() => {
|
||||
let zero = llvm_i32.const_zero();
|
||||
let len = ctx
|
||||
.build_gep_and_load(
|
||||
arg.into_pointer_value(),
|
||||
&[zero, llvm_i32.const_int(1, false)],
|
||||
None,
|
||||
)
|
||||
.into_int_value();
|
||||
ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap()
|
||||
TypeEnum::TTuple { .. } => {
|
||||
let tuple = TupleType::from_unifier_type(generator, ctx, arg_ty)
|
||||
.map_value(arg.into_struct_value(), None);
|
||||
llvm_i32.const_int(tuple.get_type().num_elements().into(), false)
|
||||
}
|
||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let arg = NDArrayType::from_unifier_type(generator, ctx, arg_ty)
|
||||
|
||||
TypeEnum::TObj { obj_id, .. }
|
||||
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
|
||||
{
|
||||
let ndarray = NDArrayType::from_unifier_type(generator, ctx, arg_ty)
|
||||
.map_value(arg.into_pointer_value(), None);
|
||||
|
||||
let ndims = arg.shape().size(ctx, generator);
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder
|
||||
.build_int_compare(IntPredicate::NE, ndims, llvm_usize.const_zero(), "")
|
||||
.unwrap(),
|
||||
"0:TypeError",
|
||||
"len() of unsized object",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
let len = unsafe {
|
||||
arg.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
||||
};
|
||||
|
||||
ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap()
|
||||
ctx.builder
|
||||
.build_int_truncate_or_bit_cast(ndarray.len(generator, ctx), llvm_i32, "len")
|
||||
.unwrap()
|
||||
}
|
||||
_ => codegen_unreachable!(ctx),
|
||||
|
||||
TypeEnum::TObj { obj_id, .. }
|
||||
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
|
||||
{
|
||||
let list = ListType::from_unifier_type(generator, ctx, arg_ty)
|
||||
.map_value(arg.into_pointer_value(), None);
|
||||
ctx.builder
|
||||
.build_int_truncate_or_bit_cast(list.load_size(ctx, None), llvm_i32, "len")
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
_ => unsupported_type(ctx, "len", &[arg_ty]),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user