[core] codegen: Reimplement builtin funcs to support strided ndarrays

Based on 7f3c4530: core/ndstrides: update builtin_fns to use ndarray
with strides
This commit is contained in:
David Mak 2024-12-19 11:24:28 +08:00
parent 61ba5fa2e5
commit 90c19bed16
3 changed files with 376 additions and 610 deletions

File diff suppressed because it is too large Load Diff

View File

@ -61,12 +61,7 @@ impl<'ctx> NDArrayType<'ctx> {
// Use an existing ndarray. // Use an existing ndarray.
// Check that its shape is compatible with the broadcast shape. // Check that its shape is compatible with the broadcast shape.
result_ndarray.assert_can_be_written_by_out( result_ndarray.assert_can_be_written_by_out(generator, ctx, broadcast_result.shape);
generator,
ctx,
broadcast_result.ndims,
broadcast_result.shape,
);
result_ndarray result_ndarray
} }
}; };

View File

@ -8,9 +8,9 @@ use inkwell::{
use itertools::Itertools; use itertools::Itertools;
use super::{ use super::{
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ProxyValue, TupleValue, ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TupleValue, TypedArrayLikeAccessor,
TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor,
UntypedArrayLikeAccessor, UntypedArrayLikeMutator, UntypedArrayLikeMutator,
}; };
use crate::{ use crate::{
codegen::{ codegen::{
@ -531,24 +531,18 @@ impl<'ctx> NDArrayValue<'ctx> {
&self, &self,
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
out_ndims: u64, out_shape: impl TypedArrayLikeAccessor<'ctx, G, IntValue<'ctx>>,
out_shape: ArraySliceValue<'ctx>,
) { ) {
assert!(self.ndims.is_some(), "NDArrayValue::assert_can_be_written_by_out can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))"); assert!(self.ndims.is_some(), "NDArrayValue::assert_can_be_written_by_out can only be called on an instance with compile-time known ndims (self.ndims = Some(ndims))");
let ndarray_ndims = self.llvm_usize.const_int(self.ndims.unwrap(), false); let ndarray_shape = self.shape();
let ndarray_shape = self.shape().base_ptr(ctx, generator);
let output_ndims = self.llvm_usize.const_int(out_ndims, false);
let output_shape = out_shape; let output_shape = out_shape;
irrt::ndarray::call_nac3_ndarray_util_assert_output_shape_same( irrt::ndarray::call_nac3_ndarray_util_assert_output_shape_same(
generator, generator,
ctx, ctx,
ndarray_ndims, &ndarray_shape,
ndarray_shape, &output_shape,
output_ndims,
output_shape.base_ptr(ctx, generator),
); );
} }
} }