forked from M-Labs/nac3
core/ndstrides: implement len(ndarray) & refactor len()
This commit is contained in:
parent
3a241acc9c
commit
ad5afb52c4
|
@ -4,8 +4,7 @@ use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel};
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
|
|
||||||
use crate::codegen::classes::{
|
use crate::codegen::classes::{
|
||||||
ArrayLikeValue, NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor,
|
NDArrayValue, ProxyValue, RangeValue, UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
|
||||||
UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
|
|
||||||
};
|
};
|
||||||
use crate::codegen::expr::destructure_range;
|
use crate::codegen::expr::destructure_range;
|
||||||
use crate::codegen::irrt::calculate_len_for_slice_range;
|
use crate::codegen::irrt::calculate_len_for_slice_range;
|
||||||
|
@ -16,6 +15,12 @@ use crate::toplevel::helper::PrimDef;
|
||||||
use crate::toplevel::numpy::unpack_ndarray_var_tys;
|
use crate::toplevel::numpy::unpack_ndarray_var_tys;
|
||||||
use crate::typecheck::typedef::{Type, TypeEnum};
|
use crate::typecheck::typedef::{Type, TypeEnum};
|
||||||
|
|
||||||
|
use super::model::*;
|
||||||
|
use super::object::any::AnyObject;
|
||||||
|
use super::object::list::ListObject;
|
||||||
|
use super::object::ndarray::NDArrayObject;
|
||||||
|
use super::object::tuple::TupleObject;
|
||||||
|
|
||||||
/// Shorthand for [`unreachable!()`] when a type of argument is not supported.
|
/// Shorthand for [`unreachable!()`] when a type of argument is not supported.
|
||||||
///
|
///
|
||||||
/// The generated message will contain the function name and the name of the unsupported type.
|
/// The generated message will contain the function name and the name of the unsupported type.
|
||||||
|
@ -32,58 +37,33 @@ pub fn call_len<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
n: (Type, BasicValueEnum<'ctx>),
|
n: (Type, BasicValueEnum<'ctx>),
|
||||||
) -> Result<IntValue<'ctx>, String> {
|
) -> Result<IntValue<'ctx>, String> {
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
|
||||||
let range_ty = ctx.primitives.range;
|
|
||||||
let (arg_ty, arg) = n;
|
let (arg_ty, arg) = n;
|
||||||
|
Ok(if ctx.unifier.unioned(arg_ty, ctx.primitives.range) {
|
||||||
Ok(if ctx.unifier.unioned(arg_ty, range_ty) {
|
|
||||||
let arg = RangeValue::from_ptr_val(arg.into_pointer_value(), Some("range"));
|
let arg = RangeValue::from_ptr_val(arg.into_pointer_value(), Some("range"));
|
||||||
let (start, end, step) = destructure_range(ctx, arg);
|
let (start, end, step) = destructure_range(ctx, arg);
|
||||||
calculate_len_for_slice_range(generator, ctx, start, end, step)
|
calculate_len_for_slice_range(generator, ctx, start, end, step)
|
||||||
} else {
|
} else {
|
||||||
match &*ctx.unifier.get_ty_immutable(arg_ty) {
|
let arg = AnyObject { ty: arg_ty, value: arg };
|
||||||
TypeEnum::TTuple { ty, .. } => llvm_i32.const_int(ty.len() as u64, false),
|
let len: Instance<'ctx, Int<Int32>> = match &*ctx.unifier.get_ty(arg_ty) {
|
||||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::List.id() => {
|
TypeEnum::TTuple { .. } => {
|
||||||
let zero = llvm_i32.const_zero();
|
let tuple = TupleObject::from_object(ctx, arg);
|
||||||
let len = ctx
|
tuple.len(generator, ctx).truncate_or_bit_cast(generator, ctx, Int32)
|
||||||
.build_gep_and_load(
|
|
||||||
arg.into_pointer_value(),
|
|
||||||
&[zero, llvm_i32.const_int(1, false)],
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
.into_int_value();
|
|
||||||
ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap()
|
|
||||||
}
|
}
|
||||||
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
TypeEnum::TObj { obj_id, .. }
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
if *obj_id == ctx.primitives.ndarray.obj_id(&ctx.unifier).unwrap() =>
|
||||||
|
{
|
||||||
let arg = NDArrayValue::from_ptr_val(arg.into_pointer_value(), llvm_usize, None);
|
let ndarray = NDArrayObject::from_object(generator, ctx, arg);
|
||||||
|
ndarray.len(generator, ctx).truncate_or_bit_cast(generator, ctx, Int32)
|
||||||
let ndims = arg.dim_sizes().size(ctx, generator);
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
ctx.builder
|
|
||||||
.build_int_compare(IntPredicate::NE, ndims, llvm_usize.const_zero(), "")
|
|
||||||
.unwrap(),
|
|
||||||
"0:TypeError",
|
|
||||||
"len() of unsized object",
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
|
|
||||||
let len = unsafe {
|
|
||||||
arg.dim_sizes().get_typed_unchecked(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
&llvm_usize.const_zero(),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
};
|
|
||||||
|
|
||||||
ctx.builder.build_int_truncate_or_bit_cast(len, llvm_i32, "len").unwrap()
|
|
||||||
}
|
}
|
||||||
_ => unreachable!(),
|
TypeEnum::TObj { obj_id, .. }
|
||||||
}
|
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
|
||||||
|
{
|
||||||
|
let list = ListObject::from_object(generator, ctx, arg);
|
||||||
|
list.len(generator, ctx).truncate_or_bit_cast(generator, ctx, Int32)
|
||||||
|
}
|
||||||
|
_ => unsupported_type(ctx, "len", &[arg_ty]),
|
||||||
|
};
|
||||||
|
len.value
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue