1
0
forked from M-Labs/nac3

core/ndstrides: implement ndarray len()

This commit is contained in:
lyken 2024-07-26 15:50:22 +08:00
parent 02e3ddfce6
commit 5b9ac9b09c

View File

@ -9,16 +9,19 @@ use inkwell::{
IntPredicate,
};
use itertools::Either;
use ndarray::basic::call_nac3_ndarray_len;
use strum::IntoEnumIterator;
use crate::{
codegen::{
builtin_fns,
classes::{ArrayLikeValue, NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor},
classes::{ProxyValue, RangeValue},
expr::destructure_range,
irrt::*,
model::*,
numpy::*,
stmt::exn_constructor,
structure::ndarray::NpArray,
},
symbol_resolver::SymbolValue,
toplevel::{helper::PrimDef, numpy::make_ndarray_ty},
@ -1464,51 +1467,13 @@ impl<'a> BuiltinBuilder<'a> {
}
}
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let tyctx = generator.type_context(ctx.ctx);
let pndarray_model = PtrModel(StructModel(NpArray));
let arg = NDArrayValue::from_ptr_val(
arg.into_pointer_value(),
llvm_usize,
None,
);
let ndims = arg.dim_sizes().size(ctx, generator);
ctx.make_assert(
generator,
ctx.builder
.build_int_compare(
IntPredicate::NE,
ndims,
llvm_usize.const_zero(),
"",
)
.unwrap(),
"0:TypeError",
&format!("{name}() of unsized object", name = prim.name()),
[None, None, None],
ctx.current_loc,
);
let len = unsafe {
arg.dim_sizes().get_typed_unchecked(
ctx,
generator,
&llvm_usize.const_zero(),
None,
)
};
if len.get_type().get_bit_width() == 32 {
Some(len.into())
} else {
Some(
ctx.builder
.build_int_truncate(len, llvm_i32, "len")
.map(Into::into)
.unwrap(),
)
}
let ndarray =
pndarray_model.check_value(tyctx, ctx.ctx, arg).unwrap();
let len = call_nac3_ndarray_len(generator, ctx, ndarray);
Some(len.value.as_basic_value_enum())
}
_ => unreachable!(),
}