From 59bea2cd4d58272698788d0cf488f992f271dfec Mon Sep 17 00:00:00 2001 From: lyken Date: Tue, 20 Aug 2024 12:29:31 +0800 Subject: [PATCH] core/ndstrides: implement len(ndarray) & refactor len() --- nac3core/src/codegen/builtin_fns.rs | 72 ++++++++++------------------- 1 file changed, 25 insertions(+), 47 deletions(-) diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 66ac807c..91075e1d 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -5,11 +5,14 @@ use inkwell::{ }; use itertools::Itertools; +use super::{ + model::*, + object::{any::AnyObject, list::ListObject, ndarray::NDArrayObject, tuple::TupleObject}, +}; use crate::{ codegen::{ classes::{ - ArrayLikeValue, NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor, - UntypedArrayLikeAccessor, UntypedArrayLikeMutator, + NDArrayValue, ProxyValue, RangeValue, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }, expr::destructure_range, extern_fns, irrt, @@ -43,58 +46,33 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>( ctx: &mut CodeGenContext<'ctx, '_>, n: (Type, BasicValueEnum<'ctx>), ) -> Result, String> { - let llvm_i32 = ctx.ctx.i32_type(); - let range_ty = ctx.primitives.range; let (arg_ty, arg) = n; - - Ok(if ctx.unifier.unioned(arg_ty, range_ty) { + Ok(if ctx.unifier.unioned(arg_ty, ctx.primitives.range) { let arg = RangeValue::from_ptr_val(arg.into_pointer_value(), Some("range")); let (start, end, step) = destructure_range(ctx, arg); calculate_len_for_slice_range(generator, ctx, start, end, step) } else { - match &*ctx.unifier.get_ty_immutable(arg_ty) { - TypeEnum::TTuple { ty, .. } => llvm_i32.const_int(ty.len() as u64, false), - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::List.id() => { - let zero = llvm_i32.const_zero(); - let len = ctx - .build_gep_and_load( - arg.into_pointer_value(), - &[zero, llvm_i32.const_int(1, false)], - None, - ) - .into_int_value(); - ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap() + let arg = AnyObject { ty: arg_ty, value: arg }; + let len: Instance<'ctx, Int> = match &*ctx.unifier.get_ty(arg_ty) { + TypeEnum::TTuple { .. } => { + let tuple = TupleObject::from_object(ctx, arg); + tuple.len(generator, ctx).truncate_or_bit_cast(generator, ctx, Int32) } - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - let llvm_usize = generator.get_size_type(ctx.ctx); - - let arg = NDArrayValue::from_ptr_val(arg.into_pointer_value(), llvm_usize, None); - - let ndims = arg.dim_sizes().size(ctx, generator); - ctx.make_assert( - generator, - ctx.builder - .build_int_compare(IntPredicate::NE, ndims, llvm_usize.const_zero(), "") - .unwrap(), - "0:TypeError", - "len() of unsized object", - [None, None, None], - ctx.current_loc, - ); - - let len = unsafe { - arg.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_zero(), - None, - ) - }; - - ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap() + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() => + { + let ndarray = NDArrayObject::from_object(generator, ctx, arg); + ndarray.len(generator, ctx).truncate_or_bit_cast(generator, ctx, Int32) } - _ => codegen_unreachable!(ctx), - } + TypeEnum::TObj { obj_id, .. } + if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() => + { + let list = ListObject::from_object(generator, ctx, arg); + list.len(generator, ctx).truncate_or_bit_cast(generator, ctx, Int32) + } + _ => unsupported_type(ctx, "len", &[arg_ty]), + }; + len.value }) }