forked from M-Labs/nac3
[core] codegen/ndarray: Use IRRT for size() and indexing operations
Also refactor some usages of call_ndarray_calc_size with ndarray.size().
This commit is contained in:
parent
3c0ce3031f
commit
dc9efa9e8c
@ -29,25 +29,6 @@ void __nac3_ndarray_calc_nd_indices_impl(SizeT index, const SizeT* dims, SizeT n
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename SizeT>
|
|
||||||
SizeT __nac3_ndarray_flatten_index_impl(const SizeT* dims,
|
|
||||||
SizeT num_dims,
|
|
||||||
const NDIndexInt* indices,
|
|
||||||
SizeT num_indices) {
|
|
||||||
SizeT idx = 0;
|
|
||||||
SizeT stride = 1;
|
|
||||||
for (SizeT i = 0; i < num_dims; ++i) {
|
|
||||||
SizeT ri = num_dims - i - 1;
|
|
||||||
if (ri < num_indices) {
|
|
||||||
idx += stride * indices[ri];
|
|
||||||
}
|
|
||||||
|
|
||||||
__builtin_assume(dims[i] > 0);
|
|
||||||
stride *= dims[ri];
|
|
||||||
}
|
|
||||||
return idx;
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename SizeT>
|
template<typename SizeT>
|
||||||
void __nac3_ndarray_calc_broadcast_impl(const SizeT* lhs_dims,
|
void __nac3_ndarray_calc_broadcast_impl(const SizeT* lhs_dims,
|
||||||
SizeT lhs_ndims,
|
SizeT lhs_ndims,
|
||||||
@ -107,18 +88,6 @@ void __nac3_ndarray_calc_nd_indices64(uint64_t index, const uint64_t* dims, uint
|
|||||||
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
|
__nac3_ndarray_calc_nd_indices_impl(index, dims, num_dims, idxs);
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t
|
|
||||||
__nac3_ndarray_flatten_index(const uint32_t* dims, uint32_t num_dims, const NDIndexInt* indices, uint32_t num_indices) {
|
|
||||||
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint64_t __nac3_ndarray_flatten_index64(const uint64_t* dims,
|
|
||||||
uint64_t num_dims,
|
|
||||||
const NDIndexInt* indices,
|
|
||||||
uint64_t num_indices) {
|
|
||||||
return __nac3_ndarray_flatten_index_impl(dims, num_dims, indices, num_indices);
|
|
||||||
}
|
|
||||||
|
|
||||||
void __nac3_ndarray_calc_broadcast(const uint32_t* lhs_dims,
|
void __nac3_ndarray_calc_broadcast(const uint32_t* lhs_dims,
|
||||||
uint32_t lhs_ndims,
|
uint32_t lhs_ndims,
|
||||||
const uint32_t* rhs_dims,
|
const uint32_t* rhs_dims,
|
||||||
|
@ -877,8 +877,7 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, a_ty);
|
let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, a_ty);
|
||||||
|
|
||||||
let n = llvm_ndarray_ty.map_value(n, None);
|
let n = llvm_ndarray_ty.map_value(n, None);
|
||||||
let n_sz =
|
let n_sz = n.size(generator, ctx);
|
||||||
irrt::ndarray::call_ndarray_calc_size(generator, ctx, &n.shape(), (None, None));
|
|
||||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||||
let n_sz_eqz = ctx
|
let n_sz_eqz = ctx
|
||||||
.builder
|
.builder
|
||||||
@ -913,7 +912,16 @@ pub fn call_numpy_max_min<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
llvm_int64.const_int(1, false),
|
llvm_int64.const_int(1, false),
|
||||||
(n_sz, false),
|
(n_sz, false),
|
||||||
|generator, ctx, _, idx| {
|
|generator, ctx, _, idx| {
|
||||||
let elem = unsafe { n.data().get_unchecked(ctx, generator, &idx, None) };
|
let elem = unsafe {
|
||||||
|
n.data().get_unchecked(
|
||||||
|
ctx,
|
||||||
|
generator,
|
||||||
|
&ctx.builder
|
||||||
|
.build_int_truncate_or_bit_cast(idx, llvm_usize, "")
|
||||||
|
.unwrap(),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
};
|
||||||
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
|
let accumulator = ctx.builder.build_load(accumulator_addr, "").unwrap();
|
||||||
let cur_idx = ctx.builder.build_load(res_idx, "").unwrap();
|
let cur_idx = ctx.builder.build_load(res_idx, "").unwrap();
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
use inkwell::{
|
use inkwell::{
|
||||||
types::{BasicTypeEnum, IntType},
|
types::BasicTypeEnum,
|
||||||
values::{BasicValueEnum, CallSiteValue, IntValue},
|
values::{BasicValueEnum, CallSiteValue, IntValue},
|
||||||
AddressSpace, IntPredicate,
|
AddressSpace, IntPredicate,
|
||||||
};
|
};
|
||||||
@ -138,78 +138,8 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates a call to `__nac3_ndarray_flatten_index`. Returns a `usize` of the flattened index for
|
/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of
|
||||||
/// the multidimensional index.
|
/// dimension and size of each dimension of the resultant `ndarray`.
|
||||||
///
|
|
||||||
/// * `ndarray` - LLVM pointer to the `NDArray`. This value must be the LLVM representation of an
|
|
||||||
/// `NDArray`.
|
|
||||||
/// * `indices` - The multidimensional index to compute the flattened index for.
|
|
||||||
pub fn call_ndarray_flatten_index<'ctx, G, Index>(
|
|
||||||
generator: &G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
ndarray: NDArrayValue<'ctx>,
|
|
||||||
indices: &Index,
|
|
||||||
) -> IntValue<'ctx>
|
|
||||||
where
|
|
||||||
G: CodeGenerator + ?Sized,
|
|
||||||
Index: ArrayLikeIndexer<'ctx>,
|
|
||||||
{
|
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
|
|
||||||
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
|
|
||||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
|
||||||
|
|
||||||
debug_assert_eq!(
|
|
||||||
IntType::try_from(indices.element_type(ctx, generator))
|
|
||||||
.map(IntType::get_bit_width)
|
|
||||||
.unwrap_or_default(),
|
|
||||||
llvm_i32.get_bit_width(),
|
|
||||||
"Expected i32 value for argument `indices` to `call_ndarray_flatten_index_impl`"
|
|
||||||
);
|
|
||||||
debug_assert_eq!(
|
|
||||||
indices.size(ctx, generator).get_type().get_bit_width(),
|
|
||||||
llvm_usize.get_bit_width(),
|
|
||||||
"Expected usize integer value for argument `indices_size` to `call_ndarray_flatten_index_impl`"
|
|
||||||
);
|
|
||||||
|
|
||||||
let ndarray_flatten_index_fn_name =
|
|
||||||
get_usize_dependent_function_name(generator, ctx, "__nac3_ndarray_flatten_index");
|
|
||||||
let ndarray_flatten_index_fn =
|
|
||||||
ctx.module.get_function(&ndarray_flatten_index_fn_name).unwrap_or_else(|| {
|
|
||||||
let fn_type = llvm_usize.fn_type(
|
|
||||||
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()],
|
|
||||||
false,
|
|
||||||
);
|
|
||||||
|
|
||||||
ctx.module.add_function(&ndarray_flatten_index_fn_name, fn_type, None)
|
|
||||||
});
|
|
||||||
|
|
||||||
let ndarray_num_dims = ndarray.load_ndims(ctx);
|
|
||||||
let ndarray_dims = ndarray.shape();
|
|
||||||
|
|
||||||
let index = ctx
|
|
||||||
.builder
|
|
||||||
.build_call(
|
|
||||||
ndarray_flatten_index_fn,
|
|
||||||
&[
|
|
||||||
ndarray_dims.base_ptr(ctx, generator).into(),
|
|
||||||
ndarray_num_dims.into(),
|
|
||||||
indices.base_ptr(ctx, generator).into(),
|
|
||||||
indices.size(ctx, generator).into(),
|
|
||||||
],
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
.map(CallSiteValue::try_as_basic_value)
|
|
||||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
|
||||||
.map(Either::unwrap_left)
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
index
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a [`TypedArrayLikeAdapter`]
|
|
||||||
/// containing the size of each dimension of the resultant `ndarray`.
|
|
||||||
pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
|
pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
@ -21,8 +21,8 @@ use super::{
|
|||||||
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
|
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
|
||||||
types::{ndarray::NDArrayType, ListType, ProxyType},
|
types::{ndarray::NDArrayType, ListType, ProxyType},
|
||||||
values::{
|
values::{
|
||||||
ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue,
|
ndarray::NDArrayValue, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue,
|
||||||
TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator,
|
ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator,
|
||||||
UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
|
UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
|
||||||
},
|
},
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
@ -318,12 +318,7 @@ where
|
|||||||
{
|
{
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
let ndarray_num_elems = call_ndarray_calc_size(
|
let ndarray_num_elems = ndarray.size(generator, ctx);
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
&ndarray.shape().as_slice_value(ctx, generator),
|
|
||||||
(None, None),
|
|
||||||
);
|
|
||||||
|
|
||||||
gen_for_callback_incrementing(
|
gen_for_callback_incrementing(
|
||||||
generator,
|
generator,
|
||||||
@ -434,6 +429,66 @@ where
|
|||||||
rhs_val.get_type()
|
rhs_val.get_type()
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Returns the element of an ndarray indexed by the given indices, performing int-promotion on
|
||||||
|
// `indices` where necessary.
|
||||||
|
//
|
||||||
|
// Required for compatibility with `NDArrayType::get_unchecked`.
|
||||||
|
let get_data_by_indices_compat =
|
||||||
|
|generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
ndarray: NDArrayValue<'ctx>,
|
||||||
|
indices: TypedArrayLikeAdapter<'ctx, G, IntValue<'ctx>>| {
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
// Workaround: Promote lhs_idx to usize* to make the array compatible with new IRRT
|
||||||
|
let stackptr = llvm_intrinsics::call_stacksave(ctx, None);
|
||||||
|
let indices = if llvm_usize == ctx.ctx.i32_type() {
|
||||||
|
indices
|
||||||
|
} else {
|
||||||
|
let indices_usize = TypedArrayLikeAdapter::<G, IntValue<'ctx>>::from(
|
||||||
|
ArraySliceValue::from_ptr_val(
|
||||||
|
ctx.builder
|
||||||
|
.build_array_alloca(llvm_usize, indices.size(ctx, generator), "")
|
||||||
|
.unwrap(),
|
||||||
|
indices.size(ctx, generator),
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
|_, _, val| val.into_int_value(),
|
||||||
|
|_, _, val| val.into(),
|
||||||
|
);
|
||||||
|
|
||||||
|
gen_for_callback_incrementing(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
None,
|
||||||
|
llvm_usize.const_zero(),
|
||||||
|
(indices.size(ctx, generator), false),
|
||||||
|
|generator, ctx, _, i| {
|
||||||
|
let idx = unsafe { indices.get_typed_unchecked(ctx, generator, &i, None) };
|
||||||
|
let idx = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_z_extend_or_bit_cast(idx, llvm_usize, "")
|
||||||
|
.unwrap();
|
||||||
|
unsafe {
|
||||||
|
indices_usize.set_typed_unchecked(ctx, generator, &i, idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
llvm_usize.const_int(1, false),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
indices_usize
|
||||||
|
};
|
||||||
|
|
||||||
|
let elem = unsafe { ndarray.data().get_unchecked(ctx, generator, &indices, None) };
|
||||||
|
|
||||||
|
llvm_intrinsics::call_stackrestore(ctx, stackptr);
|
||||||
|
|
||||||
|
elem
|
||||||
|
};
|
||||||
|
|
||||||
// Assert that all ndarray operands are broadcastable to the target size
|
// Assert that all ndarray operands are broadcastable to the target size
|
||||||
if !lhs_scalar {
|
if !lhs_scalar {
|
||||||
let lhs_val = NDArrayType::from_unifier_type(generator, ctx, lhs_ty)
|
let lhs_val = NDArrayType::from_unifier_type(generator, ctx, lhs_ty)
|
||||||
@ -455,7 +510,7 @@ where
|
|||||||
.map_value(lhs_val.into_pointer_value(), None);
|
.map_value(lhs_val.into_pointer_value(), None);
|
||||||
let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, idx);
|
let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, idx);
|
||||||
|
|
||||||
unsafe { lhs.data().get_unchecked(ctx, generator, &lhs_idx, None) }
|
get_data_by_indices_compat(generator, ctx, lhs, lhs_idx)
|
||||||
};
|
};
|
||||||
|
|
||||||
let rhs_elem = if rhs_scalar {
|
let rhs_elem = if rhs_scalar {
|
||||||
@ -465,7 +520,7 @@ where
|
|||||||
.map_value(rhs_val.into_pointer_value(), None);
|
.map_value(rhs_val.into_pointer_value(), None);
|
||||||
let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, idx);
|
let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, idx);
|
||||||
|
|
||||||
unsafe { rhs.data().get_unchecked(ctx, generator, &rhs_idx, None) }
|
get_data_by_indices_compat(generator, ctx, rhs, rhs_idx)
|
||||||
};
|
};
|
||||||
|
|
||||||
value_fn(generator, ctx, (lhs_elem, rhs_elem))
|
value_fn(generator, ctx, (lhs_elem, rhs_elem))
|
||||||
@ -1408,7 +1463,6 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
|
|||||||
lhs: NDArrayValue<'ctx>,
|
lhs: NDArrayValue<'ctx>,
|
||||||
rhs: NDArrayValue<'ctx>,
|
rhs: NDArrayValue<'ctx>,
|
||||||
) -> Result<NDArrayValue<'ctx>, String> {
|
) -> Result<NDArrayValue<'ctx>, String> {
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
if cfg!(debug_assertions) {
|
if cfg!(debug_assertions) {
|
||||||
@ -1597,19 +1651,19 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
|
|||||||
|
|
||||||
let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None);
|
let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None);
|
||||||
|
|
||||||
ctx.builder.build_int_truncate(idx, llvm_i32, "").unwrap()
|
ctx.builder.build_int_z_extend_or_bit_cast(idx, llvm_usize, "").unwrap()
|
||||||
};
|
};
|
||||||
|
|
||||||
let idx0 = unsafe {
|
let idx0 = unsafe {
|
||||||
let idx0 = idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
|
let idx0 = idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None);
|
||||||
|
|
||||||
ctx.builder.build_int_truncate(idx0, llvm_i32, "").unwrap()
|
ctx.builder.build_int_z_extend_or_bit_cast(idx0, llvm_usize, "").unwrap()
|
||||||
};
|
};
|
||||||
let idx1 = unsafe {
|
let idx1 = unsafe {
|
||||||
let idx1 =
|
let idx1 =
|
||||||
idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None);
|
idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None);
|
||||||
|
|
||||||
ctx.builder.build_int_truncate(idx1, llvm_i32, "").unwrap()
|
ctx.builder.build_int_z_extend_or_bit_cast(idx1, llvm_usize, "").unwrap()
|
||||||
};
|
};
|
||||||
|
|
||||||
let result_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
|
let result_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?;
|
||||||
@ -1620,14 +1674,12 @@ pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>(
|
|||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
None,
|
None,
|
||||||
llvm_i32.const_zero(),
|
llvm_usize.const_zero(),
|
||||||
(common_dim, false),
|
(common_dim, false),
|
||||||
|generator, ctx, _, i| {
|
|generator, ctx, _, i| {
|
||||||
let i = ctx.builder.build_int_truncate(i, llvm_i32, "").unwrap();
|
|
||||||
|
|
||||||
let ab_idx = generator.gen_array_var_alloc(
|
let ab_idx = generator.gen_array_var_alloc(
|
||||||
ctx,
|
ctx,
|
||||||
llvm_i32.into(),
|
llvm_usize.into(),
|
||||||
llvm_usize.const_int(2, false),
|
llvm_usize.const_int(2, false),
|
||||||
None,
|
None,
|
||||||
)?;
|
)?;
|
||||||
@ -2002,7 +2054,7 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty);
|
let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty);
|
||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
let n1 = llvm_ndarray_ty.map_value(n1, None);
|
let n1 = llvm_ndarray_ty.map_value(n1, None);
|
||||||
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
|
let n_sz = n1.size(generator, ctx);
|
||||||
|
|
||||||
// Dimensions are reversed in the transposed array
|
// Dimensions are reversed in the transposed array
|
||||||
let out = create_ndarray_dyn_shape(
|
let out = create_ndarray_dyn_shape(
|
||||||
@ -2122,7 +2174,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
||||||
let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty);
|
let llvm_ndarray_ty = NDArrayType::from_unifier_type(generator, ctx, x1_ty);
|
||||||
let n1 = llvm_ndarray_ty.map_value(n1, None);
|
let n1 = llvm_ndarray_ty.map_value(n1, None);
|
||||||
let n_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
|
let n_sz = n1.size(generator, ctx);
|
||||||
|
|
||||||
let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
let acc = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||||
let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
let num_neg = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||||
@ -2350,7 +2402,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
);
|
);
|
||||||
|
|
||||||
// The new shape must be compatible with the old shape
|
// The new shape must be compatible with the old shape
|
||||||
let out_sz = call_ndarray_calc_size(generator, ctx, &out.shape(), (None, None));
|
let out_sz = out.size(generator, ctx);
|
||||||
ctx.make_assert(
|
ctx.make_assert(
|
||||||
generator,
|
generator,
|
||||||
ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(),
|
ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(),
|
||||||
@ -2407,8 +2459,8 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
let n1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(n1, None);
|
let n1 = NDArrayType::from_unifier_type(generator, ctx, x1_ty).map_value(n1, None);
|
||||||
let n2 = NDArrayType::from_unifier_type(generator, ctx, x2_ty).map_value(n2, None);
|
let n2 = NDArrayType::from_unifier_type(generator, ctx, x2_ty).map_value(n2, None);
|
||||||
|
|
||||||
let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
|
let n1_sz = n1.size(generator, ctx);
|
||||||
let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.shape(), (None, None));
|
let n2_sz = n2.size(generator, ctx);
|
||||||
|
|
||||||
ctx.make_assert(
|
ctx.make_assert(
|
||||||
generator,
|
generator,
|
||||||
|
@ -5,8 +5,8 @@ use inkwell::{
|
|||||||
};
|
};
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeMutator,
|
ArrayLikeIndexer, ArrayLikeValue, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter,
|
||||||
UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
|
TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
|
||||||
};
|
};
|
||||||
use crate::codegen::{
|
use crate::codegen::{
|
||||||
irrt,
|
irrt,
|
||||||
@ -671,12 +671,7 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> {
|
|||||||
ctx: &CodeGenContext<'ctx, '_>,
|
ctx: &CodeGenContext<'ctx, '_>,
|
||||||
generator: &G,
|
generator: &G,
|
||||||
) -> IntValue<'ctx> {
|
) -> IntValue<'ctx> {
|
||||||
irrt::ndarray::call_ndarray_calc_size(
|
irrt::ndarray::call_nac3_ndarray_len(generator, ctx, *self.0)
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
&self.as_slice_value(ctx, generator),
|
|
||||||
(None, None),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -688,24 +683,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> {
|
|||||||
idx: &IntValue<'ctx>,
|
idx: &IntValue<'ctx>,
|
||||||
name: Option<&str>,
|
name: Option<&str>,
|
||||||
) -> PointerValue<'ctx> {
|
) -> PointerValue<'ctx> {
|
||||||
let sizeof_elem = ctx
|
let ptr = irrt::ndarray::call_nac3_ndarray_get_nth_pelement(generator, ctx, *self.0, *idx);
|
||||||
.builder
|
|
||||||
.build_int_truncate_or_bit_cast(
|
|
||||||
self.element_type(ctx, generator).size_of().unwrap(),
|
|
||||||
idx.get_type(),
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
let idx = ctx.builder.build_int_mul(*idx, sizeof_elem, "").unwrap();
|
|
||||||
let ptr = unsafe {
|
|
||||||
ctx.builder
|
|
||||||
.build_in_bounds_gep(
|
|
||||||
self.base_ptr(ctx, generator),
|
|
||||||
&[idx],
|
|
||||||
name.unwrap_or_default(),
|
|
||||||
)
|
|
||||||
.unwrap()
|
|
||||||
};
|
|
||||||
|
|
||||||
// Current implementation is transparent - The returned pointer type is
|
// Current implementation is transparent - The returned pointer type is
|
||||||
// already cast into the expected type, allowing for immediately
|
// already cast into the expected type, allowing for immediately
|
||||||
@ -716,7 +694,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> {
|
|||||||
BasicTypeEnum::try_from(self.element_type(ctx, generator))
|
BasicTypeEnum::try_from(self.element_type(ctx, generator))
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.ptr_type(AddressSpace::default()),
|
.ptr_type(AddressSpace::default()),
|
||||||
"",
|
name.unwrap_or_default(),
|
||||||
)
|
)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
@ -769,52 +747,28 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
|
|||||||
indices: &Index,
|
indices: &Index,
|
||||||
name: Option<&str>,
|
name: Option<&str>,
|
||||||
) -> PointerValue<'ctx> {
|
) -> PointerValue<'ctx> {
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
assert_eq!(indices.element_type(ctx, generator), generator.get_size_type(ctx.ctx).into());
|
||||||
|
|
||||||
let indices_elem_ty = unsafe {
|
let indices = TypedArrayLikeAdapter::from(
|
||||||
indices
|
indices.as_slice_value(ctx, generator),
|
||||||
.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
|
|_, _, v| v.into_int_value(),
|
||||||
.get_type()
|
|_, _, v| v.into(),
|
||||||
.get_element_type()
|
|
||||||
};
|
|
||||||
let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else {
|
|
||||||
panic!("Expected list[int32] but got {indices_elem_ty}")
|
|
||||||
};
|
|
||||||
assert_eq!(
|
|
||||||
indices_elem_ty.get_bit_width(),
|
|
||||||
32,
|
|
||||||
"Expected list[int32] but got list[int{}]",
|
|
||||||
indices_elem_ty.get_bit_width()
|
|
||||||
);
|
);
|
||||||
|
|
||||||
let index = irrt::ndarray::call_ndarray_flatten_index(generator, ctx, *self.0, indices);
|
let ptr = irrt::ndarray::call_nac3_ndarray_get_pelement_by_indices(
|
||||||
let sizeof_elem = ctx
|
generator, ctx, *self.0, &indices,
|
||||||
.builder
|
);
|
||||||
.build_int_truncate_or_bit_cast(
|
|
||||||
self.element_type(ctx, generator).size_of().unwrap(),
|
|
||||||
index.get_type(),
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
let index = ctx.builder.build_int_mul(index, sizeof_elem, "").unwrap();
|
|
||||||
|
|
||||||
let ptr = unsafe {
|
// Current implementation is transparent - The returned pointer type is
|
||||||
ctx.builder
|
// already cast into the expected type, allowing for immediately
|
||||||
.build_in_bounds_gep(
|
// load/store.
|
||||||
self.base_ptr(ctx, generator),
|
|
||||||
&[index],
|
|
||||||
name.unwrap_or_default(),
|
|
||||||
)
|
|
||||||
.unwrap()
|
|
||||||
};
|
|
||||||
// TODO: Current implementation is transparent
|
|
||||||
ctx.builder
|
ctx.builder
|
||||||
.build_pointer_cast(
|
.build_pointer_cast(
|
||||||
ptr,
|
ptr,
|
||||||
BasicTypeEnum::try_from(self.element_type(ctx, generator))
|
BasicTypeEnum::try_from(self.element_type(ctx, generator))
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.ptr_type(AddressSpace::default()),
|
.ptr_type(AddressSpace::default()),
|
||||||
"",
|
name.unwrap_or_default(),
|
||||||
)
|
)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user