From e0d1e9a007de7b48f0552b5c7f0a164d1e5bf91f Mon Sep 17 00:00:00 2001 From: lyken Date: Tue, 20 Aug 2024 12:29:31 +0800 Subject: [PATCH] core/ndstrides: implement len(ndarray) & refactor len() --- nac3core/src/codegen/builtin_fns.rs | 74 +++++++++++------------------ 1 file changed, 27 insertions(+), 47 deletions(-) diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 9914d81c..3ac65237 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -4,8 +4,7 @@ use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel}; use itertools::Itertools; use crate::codegen::classes::{ - ArrayLikeValue, NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor, - UntypedArrayLikeAccessor, UntypedArrayLikeMutator, + NDArrayValue, ProxyValue, RangeValue, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }; use crate::codegen::expr::destructure_range; use crate::codegen::irrt::calculate_len_for_slice_range; @@ -16,6 +15,12 @@ use crate::toplevel::helper::PrimDef; use crate::toplevel::numpy::unpack_ndarray_var_tys; use crate::typecheck::typedef::{Type, TypeEnum}; +use super::model::*; +use super::object::any::AnyObject; +use super::object::list::ListObject; +use super::object::ndarray::NDArrayObject; +use super::object::tuple::TupleObject; + /// Shorthand for [`unreachable!()`] when a type of argument is not supported. /// /// The generated message will contain the function name and the name of the unsupported type. @@ -32,58 +37,33 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>( ctx: &mut CodeGenContext<'ctx, '_>, n: (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - let llvm_i32 = ctx.ctx.i32_type(); - let range_ty = ctx.primitives.range; let (arg_ty, arg) = n; - - Ok(if ctx.unifier.unioned(arg_ty, range_ty) { + Ok(if ctx.unifier.unioned(arg_ty, ctx.primitives.range) { let arg = RangeValue::from_ptr_val(arg.into_pointer_value(), Some("range")); let (start, end, step) = destructure_range(ctx, arg); calculate_len_for_slice_range(generator, ctx, start, end, step) } else { - match &*ctx.unifier.get_ty_immutable(arg_ty) { - TypeEnum::TTuple { ty, .. } => llvm_i32.const_int(ty.len() as u64, false), - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::List.id() => { - let zero = llvm_i32.const_zero(); - let len = ctx - .build_gep_and_load( - arg.into_pointer_value(), - &[zero, llvm_i32.const_int(1, false)], - None, - ) - .into_int_value(); - ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap() + let arg = AnyObject { ty: arg_ty, value: arg }; + let len: Instance<'ctx, Int> = match &*ctx.unifier.get_ty(arg_ty) { + TypeEnum::TTuple { .. } => { + let tuple = TupleObject::from_object(ctx, arg); + tuple.len(generator, ctx).truncate_or_bit_cast(generator, ctx, Int32) } - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - let llvm_usize = generator.get_size_type(ctx.ctx); - - let arg = NDArrayValue::from_ptr_val(arg.into_pointer_value(), llvm_usize, None); - - let ndims = arg.dim_sizes().size(ctx, generator); - ctx.make_assert( - generator, - ctx.builder - .build_int_compare(IntPredicate::NE, ndims, llvm_usize.const_zero(), "") - .unwrap(), - "0:TypeError", - "len() of unsized object", - [None, None, None], - ctx.current_loc, - ); - - let len = unsafe { - arg.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_zero(), - None, - ) - }; - - ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap() + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => + { + let ndarray = NDArrayObject::from_object(generator, ctx, arg); + ndarray.len(generator, ctx).truncate_or_bit_cast(generator, ctx, Int32) } - _ => unreachable!(), - } + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + let list = ListObject::from_object(generator, ctx, arg); + list.len(generator, ctx).truncate_or_bit_cast(generator, ctx, Int32) + } + _ => unsupported_type(ctx, "len", &[arg_ty]), + }; + len.value }) }