forked from M-Labs/nac3
1
0
Fork 0

core/ndstrides: implement ndarray len()

This commit is contained in:
lyken 2024-07-26 15:50:22 +08:00
parent adb43958d0
commit e14eba05d2
1 changed files with 9 additions and 45 deletions

View File

@ -9,16 +9,19 @@ use inkwell::{
IntPredicate, IntPredicate,
}; };
use itertools::Either; use itertools::Either;
use ndarray::basic::call_nac3_ndarray_len;
use strum::IntoEnumIterator; use strum::IntoEnumIterator;
use crate::{ use crate::{
codegen::{ codegen::{
builtin_fns, builtin_fns,
classes::{ArrayLikeValue, NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor}, classes::{ProxyValue, RangeValue},
expr::destructure_range, expr::destructure_range,
irrt::*, irrt::*,
model::*,
numpy::*, numpy::*,
stmt::exn_constructor, stmt::exn_constructor,
structs::ndarray::NpArray,
}, },
symbol_resolver::SymbolValue, symbol_resolver::SymbolValue,
toplevel::{helper::PrimDef, numpy::make_ndarray_ty}, toplevel::{helper::PrimDef, numpy::make_ndarray_ty},
@ -1464,51 +1467,12 @@ impl<'a> BuiltinBuilder<'a> {
} }
} }
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let llvm_i32 = ctx.ctx.i32_type(); let sizet = generator.get_sizet(ctx.ctx);
let llvm_usize = generator.get_size_type(ctx.ctx); let pndarray_model = PointerModel(StructModel(NpArray { sizet }));
let arg = NDArrayValue::from_ptr_val( let ndarray = pndarray_model.review_value(ctx.ctx, arg).unwrap();
arg.into_pointer_value(), let len = call_nac3_ndarray_len(generator, ctx, ndarray);
llvm_usize, Some(len.value.as_basic_value_enum())
None,
);
let ndims = arg.dim_sizes().size(ctx, generator);
ctx.make_assert(
generator,
ctx.builder
.build_int_compare(
IntPredicate::NE,
ndims,
llvm_usize.const_zero(),
"",
)
.unwrap(),
"0:TypeError",
&format!("{name}() of unsized object", name = prim.name()),
[None, None, None],
ctx.current_loc,
);
let len = unsafe {
arg.dim_sizes().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_zero(),
None,
)
};
if len.get_type().get_bit_width() == 32 {
Some(len.into())
} else {
Some(
ctx.builder
.build_int_truncate(len, llvm_i32, "len")
.map(Into::into)
.unwrap(),
)
}
} }
_ => unreachable!(), _ => unreachable!(),
} }