This commit is contained in:
David Mak 2024-03-19 17:38:09 +08:00
parent 2d1f243975
commit b2994ff90a
9 changed files with 148 additions and 26 deletions

View File

@ -11,7 +11,7 @@ use crate::codegen::{
stmt::gen_for_callback_incrementing, stmt::gen_for_callback_incrementing,
}; };
/// An LLVM value that is array-like, i.e. it contains a contiguous, sequenced collection of /// An LLVM value that is array-like, i.e. it contains a contiguous, sequenced collection of
/// elements. /// elements.
pub trait ArrayLikeValue<'ctx> { pub trait ArrayLikeValue<'ctx> {
/// Returns the element type of this array-like value. /// Returns the element type of this array-like value.

View File

@ -1129,9 +1129,12 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
Some("f_pow_i") Some("f_pow_i")
); );
Ok(Some(res.into())) Ok(Some(res.into()))
} else if ty1 == ty2 && matches!(&*ctx.unifier.get_ty(ty1), TypeEnum::TObj { obj_id, .. } if obj_id == &PRIMITIVE_DEF_IDS.ndarray) { } else if matches!(&*ctx.unifier.get_ty(ty1), TypeEnum::TObj { obj_id, .. } if obj_id == &PRIMITIVE_DEF_IDS.ndarray) && matches!(&*ctx.unifier.get_ty(ty2), TypeEnum::TObj { obj_id, .. } if obj_id == &PRIMITIVE_DEF_IDS.ndarray) {
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1); let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1);
// Unpack the dtype of the right-hand operand (`ty2`); unpacking `ty1` a second time would make
// the dtype-compatibility assertion that follows compare the left operand with itself.
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty2);
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
let left_val = NDArrayValue::from_ptr_val( let left_val = NDArrayValue::from_ptr_val(
left_val.into_pointer_value(), left_val.into_pointer_value(),
@ -1146,7 +1149,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
let res = numpy::ndarray_elementwise_binop_impl( let res = numpy::ndarray_elementwise_binop_impl(
generator, generator,
ctx, ctx,
ndarray_dtype, ndarray_dtype1,
if is_aug_assign { Some(left_val) } else { None }, if is_aug_assign { Some(left_val) } else { None },
left_val, left_val,
right_val, right_val,

View File

@ -355,3 +355,27 @@ void __nac3_ndarray_calc_broadcast64(
} }
} }
} }
/*
 * Maps an index into a broadcast result onto an index into the source array
 * (32-bit size type).
 *
 * For every axis of the source array, walking from the last axis to the first,
 * the output index is 0 when the source extent on that axis is 1, and the
 * corresponding input index otherwise.
 *
 * src_dims  - extents of the source array, `src_ndims` entries
 * src_ndims - number of axes of the source array
 * in_idx    - index to translate; entries [0, src_ndims) are read
 * out_idx   - receives the translated index; entries [0, src_ndims) are written
 */
void __nac3_ndarray_calc_broadcast_idx(
    const uint32_t *src_dims,
    uint32_t src_ndims,
    const uint32_t *in_idx,
    uint32_t *out_idx
) {
    uint32_t axis = src_ndims;

    while (axis > 0) {
        --axis;

        if (src_dims[axis] == 1) {
            out_idx[axis] = 0;
        } else {
            out_idx[axis] = in_idx[axis];
        }
    }
}
/*
 * Maps an index into a broadcast result onto an index into the source array
 * (64-bit size type).
 *
 * For every axis of the source array, walking from the last axis to the first,
 * the output index is 0 when the source extent on that axis is 1, and the
 * corresponding input index otherwise.
 *
 * src_dims  - extents of the source array, `src_ndims` entries
 * src_ndims - number of axes of the source array
 * in_idx    - index to translate; entries [0, src_ndims) are read
 * out_idx   - receives the translated index; entries [0, src_ndims) are written
 */
void __nac3_ndarray_calc_broadcast_idx64(
    const uint64_t *src_dims,
    uint64_t src_ndims,
    const uint64_t *in_idx,
    uint64_t *out_idx
) {
    uint64_t axis = src_ndims;

    while (axis > 0) {
        --axis;

        if (src_dims[axis] == 1) {
            out_idx[axis] = 0;
        } else {
            out_idx[axis] = in_idx[axis];
        }
    }
}

View File

@ -1,7 +1,15 @@
use crate::typecheck::typedef::Type; use crate::typecheck::typedef::Type;
use super::{ use super::{
classes::{ArrayLikeIndexer, ArrayLikeValue, ListValue, NDArrayValue, UntypedArrayLikeMutator}, classes::{
ArrayLikeIndexer,
ArraySliceValue,
ArrayLikeValue,
ListValue,
NDArrayValue,
UntypedArrayLikeAccessor,
UntypedArrayLikeMutator,
},
CodeGenContext, CodeGenContext,
CodeGenerator, CodeGenerator,
llvm_intrinsics, llvm_intrinsics,
@ -630,7 +638,7 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
index: IntValue<'ctx>, index: IntValue<'ctx>,
ndarray: NDArrayValue<'ctx>, ndarray: NDArrayValue<'ctx>,
) -> PointerValue<'ctx> { ) -> ArraySliceValue<'ctx> {
let llvm_void = ctx.ctx.void_type(); let llvm_void = ctx.ctx.void_type();
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
@ -677,7 +685,7 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
) )
.unwrap(); .unwrap();
indices ArraySliceValue::from_ptr_val(indices, ndarray_num_dims, None)
} }
fn call_ndarray_flatten_index_impl<'ctx, G, Indices>( fn call_ndarray_flatten_index_impl<'ctx, G, Indices>(
@ -889,4 +897,67 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
.unwrap(); .unwrap();
(max_ndims, out_dims) (max_ndims, out_dims)
}
/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`]
/// containing the indices used for accessing `array` corresponding to the `broadcast_idx`.
pub fn call_ndarray_calc_broadcast_index<'ctx, G: CodeGenerator + ?Sized, BroadcastIdx: UntypedArrayLikeAccessor<'ctx>>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
array: NDArrayValue<'ctx>,
broadcast_idx: &BroadcastIdx,
) -> ArraySliceValue<'ctx> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast_idx",
64 => "__nac3_ndarray_calc_broadcast_idx64",
bw => unreachable!("Unsupported size type bit width: {}", bw)
};
let ndarray_calc_broadcast_fn = ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[
llvm_pusize.into(),
llvm_usize.into(),
llvm_pusize.into(),
llvm_usize.into(),
],
false,
);
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
});
// TODO: Assertions
let broadcast_size = broadcast_idx.size(ctx, generator);
let out_idx = ctx.builder.build_array_alloca(llvm_usize, broadcast_size, "").unwrap();
let out_idx = ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None);
let array_dims = array.dim_sizes().base_ptr(ctx, generator);
let array_ndims = array.load_ndims(ctx);
let broadcast_idx_ptr = unsafe {
broadcast_idx.ptr_offset_unchecked(
ctx,
generator,
llvm_usize.const_zero(),
None
)
};
ctx.builder
.build_call(
ndarray_calc_broadcast_fn,
&[
array_dims.into(),
array_ndims.into(),
broadcast_idx_ptr.into(),
out_idx.base_ptr(ctx, generator).into(),
],
"",
)
.unwrap();
out_idx
} }

View File

@ -8,6 +8,7 @@ use crate::{
codegen::{ codegen::{
classes::{ classes::{
ArrayLikeIndexer, ArrayLikeIndexer,
ArraySliceValue,
ArrayLikeValue, ArrayLikeValue,
ListValue, ListValue,
NDArrayValue, NDArrayValue,
@ -325,7 +326,7 @@ fn ndarray_fill_indexed<'ctx, G, ValueFn>(
) -> Result<(), String> ) -> Result<(), String>
where where
G: CodeGenerator + ?Sized, G: CodeGenerator + ?Sized,
ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, PointerValue<'ctx>) -> Result<BasicValueEnum<'ctx>, String>, ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, ArraySliceValue<'ctx>) -> Result<BasicValueEnum<'ctx>, String>,
{ {
ndarray_fill_flattened( ndarray_fill_flattened(
generator, generator,
@ -346,7 +347,7 @@ fn ndarray_fill_indexed<'ctx, G, ValueFn>(
/// Generates the LLVM IR for populating the entire `NDArray` using a lambda with the same-indexed /// Generates the LLVM IR for populating the entire `NDArray` using a lambda with the same-indexed
/// element from two other `NDArray` as its input. /// element from two other `NDArray` as its input.
fn ndarray_fill_zip_map_flattened<'ctx, G, ValueFn>( fn ndarray_broadcast_fill_flattened<'ctx, G, ValueFn>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type, elem_ty: Type,
@ -535,16 +536,12 @@ fn call_ndarray_eye_impl<'ctx, G: CodeGenerator + ?Sized>(
ctx, ctx,
ndarray, ndarray,
|generator, ctx, indices| { |generator, ctx, indices| {
let row = ctx.build_gep_and_load( let (row, col) = unsafe {
indices, (
&[llvm_usize.const_int(0, false)], indices.get_unchecked(ctx, generator, llvm_usize.const_int(0, false), None).into_int_value(),
None, indices.get_unchecked(ctx, generator, llvm_usize.const_int(1, false), None).into_int_value(),
).into_int_value(); )
let col = ctx.build_gep_and_load( };
indices,
&[llvm_usize.const_int(1, false)],
None,
).into_int_value();
let col_with_offset = ctx.builder let col_with_offset = ctx.builder
.build_int_add( .build_int_add(
@ -660,7 +657,7 @@ pub fn ndarray_elementwise_binop_impl<'ctx, G, ValueFn>(
).unwrap() ).unwrap()
}); });
ndarray_fill_zip_map_flattened( ndarray_broadcast_fill_flattened(
generator, generator,
ctx, ctx,
elem_ty, elem_ty,

View File

@ -299,6 +299,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
Some("N".into()), Some("N".into()),
None, None,
); );
let size_t = primitives.0.usize();
let var_map: VarMap = vec![(num_ty.1, num_ty.0)].into_iter().collect(); let var_map: VarMap = vec![(num_ty.1, num_ty.0)].into_iter().collect();
let exception_fields = vec![ let exception_fields = vec![
("__name__".into(), int32, true), ("__name__".into(), int32, true),
@ -345,6 +347,11 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
.nth(1) .nth(1)
.map(|(var_id, ty)| (*ty, *var_id)) .map(|(var_id, ty)| (*ty, *var_id))
.unwrap(); .unwrap();
let ndarray_usized_ndims_tvar = primitives.1.get_fresh_const_generic_var(
size_t,
Some("ndarray_ndims".into()),
None,
);
let ndarray_copy_ty = *ndarray_fields.get(&"copy".into()).unwrap(); let ndarray_copy_ty = *ndarray_fields.get(&"copy".into()).unwrap();
let ndarray_fill_ty = *ndarray_fields.get(&"fill".into()).unwrap(); let ndarray_fill_ty = *ndarray_fields.get(&"fill".into()).unwrap();
let ndarray_add_ty = *ndarray_fields.get(&"__add__".into()).unwrap(); let ndarray_add_ty = *ndarray_fields.get(&"__add__".into()).unwrap();
@ -699,7 +706,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
name: "ndarray.__iadd__".into(), name: "ndarray.__iadd__".into(),
simple_name: "__iadd__".into(), simple_name: "__iadd__".into(),
signature: ndarray_iadd_ty.0, signature: ndarray_iadd_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id, ndarray_usized_ndims_tvar.1],
instance_to_symbol: HashMap::default(), instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(), instance_to_stmt: HashMap::default(),
resolver: None, resolver: None,

View File

@ -285,8 +285,11 @@ impl TopLevelComposer {
]), ]),
}); });
let ndarray_usized_ndims_tvar = unifier.get_fresh_const_generic_var(size_t_ty, Some("ndarray_ndims".into()), None);
let ndarray_unsized = subst_ndarray_tvars(&mut unifier, ndarray, Some(ndarray_usized_ndims_tvar.0), None);
unifier.unify(ndarray_copy_fun_ret_ty.0, ndarray).unwrap(); unifier.unify(ndarray_copy_fun_ret_ty.0, ndarray).unwrap();
unifier.unify(ndarray_binop_fun_other_ty.0, ndarray).unwrap(); unifier.unify(ndarray_binop_fun_other_ty.0, ndarray_unsized).unwrap();
unifier.unify(ndarray_binop_fun_ret_ty.0, ndarray).unwrap(); unifier.unify(ndarray_binop_fun_ret_ty.0, ndarray).unwrap();
let ndarray_float = subst_ndarray_tvars(&mut unifier, ndarray, Some(float), None); let ndarray_float = subst_ndarray_tvars(&mut unifier, ndarray, Some(float), None);

View File

@ -309,6 +309,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
ndarray: ndarray_t, ndarray: ndarray_t,
.. ..
} = *store; } = *store;
let size_t = store.usize();
/* int ======== */ /* int ======== */
for t in [int32_t, int64_t, uint32_t, uint64_t] { for t in [int32_t, int64_t, uint32_t, uint64_t] {
@ -345,9 +346,11 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
/* ndarray ===== */ /* ndarray ===== */
let ndarray_float_t = make_ndarray_ty(unifier, store, Some(float_t), None); let ndarray_float_t = make_ndarray_ty(unifier, store, Some(float_t), None);
impl_basic_arithmetic(unifier, store, ndarray_t, &[ndarray_t], ndarray_t); let ndarray_usized_ndims_tvar = unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None);
impl_pow(unifier, store, ndarray_t, &[ndarray_t], ndarray_t); let ndarray_unsized_t = make_ndarray_ty(unifier, store, None, Some(ndarray_usized_ndims_tvar.0));
impl_basic_arithmetic(unifier, store, ndarray_t, &[ndarray_unsized_t], ndarray_t);
impl_pow(unifier, store, ndarray_t, &[ndarray_unsized_t], ndarray_t);
impl_div(unifier, store, ndarray_t, &[ndarray_t], ndarray_float_t); impl_div(unifier, store, ndarray_t, &[ndarray_t], ndarray_float_t);
impl_floordiv(unifier, store, ndarray_t, &[ndarray_t], ndarray_t); impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t], ndarray_t);
impl_mod(unifier, store, ndarray_t, &[ndarray_t], ndarray_t); impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t], ndarray_t);
} }

View File

@ -81,6 +81,20 @@ def test_ndarray_add():
output_float64(y[1][0]) output_float64(y[1][0])
output_float64(y[1][1]) output_float64(y[1][1])
# def test_ndarray_add_broadcast():
# x = np_identity(2)
# y: ndarray[float, 2] = x + np_ones([2])
#
# output_float64(x[0][0])
# output_float64(x[0][1])
# output_float64(x[1][0])
# output_float64(x[1][1])
#
# output_float64(y[0][0])
# output_float64(y[0][1])
# output_float64(y[1][0])
# output_float64(y[1][1])
def test_ndarray_iadd(): def test_ndarray_iadd():
x = np_identity(2) x = np_identity(2)
x += np_ones([2, 2]) x += np_ones([2, 2])