forked from M-Labs/nac3
1
0
Fork 0

core/ndstrides: implement & refactor call_len()

This commit is contained in:
lyken 2024-08-20 12:29:31 +08:00
parent 1c317f9205
commit 5411ac5c88
No known key found for this signature in database
GPG Key ID: 3BD5FC6AC8325DD8
1 changed files with 23 additions and 50 deletions

View File

@ -4,8 +4,7 @@ use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
use itertools::Itertools; use itertools::Itertools;
use crate::codegen::classes::{ use crate::codegen::classes::{
ArrayLikeValue, NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor, NDArrayValue, ProxyValue, RangeValue, UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
}; };
use crate::codegen::expr::destructure_range; use crate::codegen::expr::destructure_range;
use crate::codegen::irrt::calculate_len_for_slice_range; use crate::codegen::irrt::calculate_len_for_slice_range;
@ -14,7 +13,13 @@ use crate::codegen::stmt::gen_for_callback_incrementing;
use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator}; use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator};
use crate::toplevel::helper::PrimDef; use crate::toplevel::helper::PrimDef;
use crate::toplevel::numpy::unpack_ndarray_var_tys; use crate::toplevel::numpy::unpack_ndarray_var_tys;
use crate::typecheck::typedef::{Type, TypeEnum}; use crate::typecheck::typedef::Type;
use super::model::*;
use super::object::any::AnyObject;
use super::object::list::ListObject;
use super::object::ndarray::NDArrayObject;
use super::object::tuple::TupleObject;
/// Shorthand for [`unreachable!()`] when a type of argument is not supported. /// Shorthand for [`unreachable!()`] when a type of argument is not supported.
/// ///
@ -32,58 +37,26 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>(
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
n: (Type, BasicValueEnum<'ctx>), n: (Type, BasicValueEnum<'ctx>),
) -> Result<IntValue<'ctx>, String> { ) -> Result<IntValue<'ctx>, String> {
let llvm_i32 = ctx.ctx.i32_type();
let range_ty = ctx.primitives.range;
let (arg_ty, arg) = n; let (arg_ty, arg) = n;
Ok(if ctx.unifier.unioned(arg_ty, ctx.primitives.range) {
Ok(if ctx.unifier.unioned(arg_ty, range_ty) {
let arg = RangeValue::from_ptr_val(arg.into_pointer_value(), Some("range")); let arg = RangeValue::from_ptr_val(arg.into_pointer_value(), Some("range"));
let (start, end, step) = destructure_range(ctx, arg); let (start, end, step) = destructure_range(ctx, arg);
calculate_len_for_slice_range(generator, ctx, start, end, step) calculate_len_for_slice_range(generator, ctx, start, end, step)
} else { } else {
match &*ctx.unifier.get_ty_immutable(arg_ty) { let arg = AnyObject { ty: arg_ty, value: arg };
TypeEnum::TTuple { ty, .. } => llvm_i32.const_int(ty.len() as u64, false), let len = if NDArrayObject::is_instance(ctx, arg) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::List.id() => { let ndarray = NDArrayObject::from_object(generator, ctx, arg);
let zero = llvm_i32.const_zero(); ndarray.len(generator, ctx)
let len = ctx } else if TupleObject::is_instance(ctx, arg) {
.build_gep_and_load( let tuple = TupleObject::from_object(ctx, arg);
arg.into_pointer_value(), tuple.len(generator, ctx)
&[zero, llvm_i32.const_int(1, false)], } else if ListObject::is_instance(ctx, arg) {
None, let list = ListObject::from_object(generator, ctx, arg);
) list.len(generator, ctx)
.into_int_value(); } else {
ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap() unsupported_type(ctx, "len", &[arg_ty])
}
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let llvm_usize = generator.get_size_type(ctx.ctx);
let arg = NDArrayValue::from_ptr_val(arg.into_pointer_value(), llvm_usize, None);
let ndims = arg.dim_sizes().size(ctx, generator);
ctx.make_assert(
generator,
ctx.builder
.build_int_compare(IntPredicate::NE, ndims, llvm_usize.const_zero(), "")
.unwrap(),
"0:TypeError",
"len() of unsized object",
[None, None, None],
ctx.current_loc,
);
let len = unsafe {
arg.dim_sizes().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_zero(),
None,
)
}; };
len.truncate(generator, ctx, Int32).value
ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap()
}
_ => unreachable!(),
}
}) })
} }