From 9a9eeba28d73ff9fae21b234f8231418c0dcb455 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 17 Dec 2024 14:21:13 +0800 Subject: [PATCH] [core] codegen: Refactor len() Based on 54a842a9: core/ndstrides: implement len(ndarray) & refactor len() --- nac3core/src/codegen/builtin_fns.rs | 61 ++++++++++++----------------- 1 file changed, 26 insertions(+), 35 deletions(-) diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 32b95a75..54650ab3 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -14,9 +14,9 @@ use super::{ numpy, numpy::ndarray_elementwise_unaryop_impl, stmt::gen_for_callback_incrementing, - types::{ndarray::NDArrayType, TupleType}, + types::{ndarray::NDArrayType, ListType, TupleType}, values::{ - ndarray::NDArrayValue, ArrayLikeValue, ProxyValue, RangeValue, TypedArrayLikeAccessor, + ndarray::NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor, }, CodeGenContext, CodeGenerator, @@ -55,42 +55,33 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>( calculate_len_for_slice_range(generator, ctx, start, end, step) } else { match &*ctx.unifier.get_ty_immutable(arg_ty) { - TypeEnum::TTuple { ty, .. } => llvm_i32.const_int(ty.len() as u64, false), - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::List.id() => { - let zero = llvm_i32.const_zero(); - let len = ctx - .build_gep_and_load( - arg.into_pointer_value(), - &[zero, llvm_i32.const_int(1, false)], - None, - ) - .into_int_value(); - ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap() + TypeEnum::TTuple { .. } => { + let tuple = TupleType::from_unifier_type(generator, ctx, arg_ty) + .map_value(arg.into_struct_value(), None); + llvm_i32.const_int(tuple.get_type().num_elements().into(), false) } - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - let llvm_usize = generator.get_size_type(ctx.ctx); - let arg = NDArrayType::from_unifier_type(generator, ctx, arg_ty) + + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => + { + let ndarray = NDArrayType::from_unifier_type(generator, ctx, arg_ty) .map_value(arg.into_pointer_value(), None); - - let ndims = arg.shape().size(ctx, generator); - ctx.make_assert( - generator, - ctx.builder - .build_int_compare(IntPredicate::NE, ndims, llvm_usize.const_zero(), "") - .unwrap(), - "0:TypeError", - "len() of unsized object", - [None, None, None], - ctx.current_loc, - ); - - let len = unsafe { - arg.shape().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - - ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap() + ctx.builder + .build_int_truncate_or_bit_cast(ndarray.len(generator, ctx), llvm_i32, "len") + .unwrap() } - _ => codegen_unreachable!(ctx), + + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + let list = ListType::from_unifier_type(generator, ctx, arg_ty) + .map_value(arg.into_pointer_value(), None); + ctx.builder + .build_int_truncate_or_bit_cast(list.load_size(ctx, None), llvm_i32, "len") + .unwrap() + } + + _ => unsupported_type(ctx, "len", &[arg_ty]), } }) }