WIP
This commit is contained in:
parent
aea52af1a6
commit
fa04cfcdc8
|
@ -160,6 +160,8 @@ pub trait TypedArrayLikeMutator<'ctx, T, Index = IntValue<'ctx>>: UntypedArrayLi
|
|||
}
|
||||
}
|
||||
|
||||
// TODO: Add TypedArrayLikeAccessAdapter and TypedArrayLikeMutateAdapter
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct ArrayAllocaValue<'ctx>(PointerValue<'ctx>, IntValue<'ctx>, Option<&'ctx str>);
|
||||
|
||||
|
|
|
@ -1122,9 +1122,12 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||
Some("f_pow_i")
|
||||
);
|
||||
Ok(Some(res.into()))
|
||||
} else if ty1 == ty2 && matches!(&*ctx.unifier.get_ty(ty1), TypeEnum::TObj { obj_id, .. } if obj_id == &PRIMITIVE_DEF_IDS.ndarray) {
|
||||
} else if matches!(&*ctx.unifier.get_ty(ty1), TypeEnum::TObj { obj_id, .. } if obj_id == &PRIMITIVE_DEF_IDS.ndarray) && matches!(&*ctx.unifier.get_ty(ty2), TypeEnum::TObj { obj_id, .. } if obj_id == &PRIMITIVE_DEF_IDS.ndarray) {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1);
|
||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1);
|
||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1);
|
||||
|
||||
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||
|
||||
let left_val = NDArrayValue::from_ptr_val(
|
||||
left_val.into_pointer_value(),
|
||||
|
@ -1139,7 +1142,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||
let res = numpy::ndarray_elementwise_binop_impl(
|
||||
generator,
|
||||
ctx,
|
||||
ndarray_dtype,
|
||||
ndarray_dtype1,
|
||||
if is_aug_assign { Some(left_val) } else { None },
|
||||
left_val,
|
||||
right_val,
|
||||
|
|
|
@ -355,3 +355,27 @@ void __nac3_ndarray_calc_broadcast_sz64(
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
void __nac3_ndarray_calc_broadcast_idx(
|
||||
const uint32_t *src_dims,
|
||||
uint32_t src_ndims,
|
||||
const uint32_t *in_idx,
|
||||
uint32_t *out_idx
|
||||
) {
|
||||
for (uint32_t i = 0; i < src_ndims; ++i) {
|
||||
uint32_t src_i = src_ndims - i - 1;
|
||||
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i];
|
||||
}
|
||||
}
|
||||
|
||||
void __nac3_ndarray_calc_broadcast_idx64(
|
||||
const uint64_t *src_dims,
|
||||
uint64_t src_ndims,
|
||||
const uint64_t *in_idx,
|
||||
uint64_t *out_idx
|
||||
) {
|
||||
for (uint64_t i = 0; i < src_ndims; ++i) {
|
||||
uint64_t src_i = src_ndims - i - 1;
|
||||
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i];
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@ use inkwell::{
|
|||
};
|
||||
use itertools::Either;
|
||||
use nac3parser::ast::Expr;
|
||||
use crate::codegen::classes::ArrayAllocaValue;
|
||||
|
||||
#[must_use]
|
||||
pub fn load_irrt(ctx: &Context) -> Module {
|
||||
|
@ -629,7 +630,7 @@ pub fn call_ndarray_calc_nd_indices<'ctx>(
|
|||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
index: IntValue<'ctx>,
|
||||
ndarray: NDArrayValue<'ctx>,
|
||||
) -> PointerValue<'ctx> {
|
||||
) -> ArrayAllocaValue<'ctx> {
|
||||
let llvm_void = ctx.ctx.void_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
|
@ -676,7 +677,7 @@ pub fn call_ndarray_calc_nd_indices<'ctx>(
|
|||
)
|
||||
.unwrap();
|
||||
|
||||
indices
|
||||
ArrayAllocaValue::from_ptr_val(indices, ndarray_num_dims, None)
|
||||
}
|
||||
|
||||
fn call_ndarray_flatten_index_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||
|
@ -942,4 +943,67 @@ pub fn call_ndarray_calc_broadcast_sz<'ctx, G: CodeGenerator + ?Sized>(
|
|||
.unwrap();
|
||||
|
||||
(max_ndims, out_dims)
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`]
|
||||
/// containing the indices used for accessing `array` corresponding to the `broadcast_idx`.
|
||||
pub fn call_ndarray_calc_broadcast_index<'ctx, G: CodeGenerator + ?Sized, BroadcastIdx: UntypedArrayLikeAccessor<'ctx>>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
array: NDArrayValue<'ctx>,
|
||||
broadcast_idx: BroadcastIdx,
|
||||
) -> ArrayAllocaValue<'ctx> {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||
|
||||
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
|
||||
32 => "__nac3_ndarray_calc_broadcast_idx",
|
||||
64 => "__nac3_ndarray_calc_broadcast_idx64",
|
||||
bw => unreachable!("Unsupported size type bit width: {}", bw)
|
||||
};
|
||||
let ndarray_calc_broadcast_fn = ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
|
||||
let fn_type = llvm_usize.fn_type(
|
||||
&[
|
||||
llvm_pusize.into(),
|
||||
llvm_usize.into(),
|
||||
llvm_pusize.into(),
|
||||
llvm_usize.into(),
|
||||
],
|
||||
false,
|
||||
);
|
||||
|
||||
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
|
||||
});
|
||||
|
||||
// TODO: Assertions
|
||||
|
||||
let broadcast_size = broadcast_idx.size(ctx, generator);
|
||||
let out_idx = ctx.builder.build_array_alloca(llvm_usize, broadcast_size, "").unwrap();
|
||||
let out_idx = ArrayAllocaValue::from_ptr_val(out_idx, broadcast_size, None);
|
||||
|
||||
let array_dims = array.dim_sizes().as_ptr_value(ctx);
|
||||
let array_ndims = array.load_ndims(ctx);
|
||||
let broadcast_idx_ptr = unsafe {
|
||||
broadcast_idx.ptr_offset_unchecked(
|
||||
ctx,
|
||||
generator,
|
||||
llvm_usize.const_zero(),
|
||||
None
|
||||
)
|
||||
};
|
||||
|
||||
ctx.builder
|
||||
.build_call(
|
||||
ndarray_calc_broadcast_fn,
|
||||
&[
|
||||
array_dims.into(),
|
||||
array_ndims.into(),
|
||||
broadcast_idx_ptr.into(),
|
||||
out_idx.as_ptr_value().into(),
|
||||
],
|
||||
"",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
out_idx
|
||||
}
|
|
@ -30,6 +30,7 @@ use crate::{
|
|||
},
|
||||
typecheck::typedef::{FunSignature, Type},
|
||||
};
|
||||
use crate::codegen::classes::ArrayAllocaValue;
|
||||
|
||||
/// Creates an `NDArray` instance from a dynamic shape.
|
||||
///
|
||||
|
@ -326,7 +327,7 @@ fn ndarray_fill_indexed<'ctx, ValueFn>(
|
|||
value_fn: ValueFn,
|
||||
) -> Result<(), String>
|
||||
where
|
||||
ValueFn: Fn(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, '_>, PointerValue<'ctx>) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
ValueFn: Fn(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, '_>, ArrayAllocaValue<'ctx>) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
{
|
||||
ndarray_fill_flattened(
|
||||
generator,
|
||||
|
@ -347,7 +348,7 @@ fn ndarray_fill_indexed<'ctx, ValueFn>(
|
|||
|
||||
/// Generates the LLVM IR for populating the entire `NDArray` using a lambda with the same-indexed
|
||||
/// element from two other `NDArray` as its input.
|
||||
fn ndarray_fill_zip_map_flattened<'ctx, G, ValueFn>(
|
||||
fn ndarray_broadcast_fill_flattened<'ctx, G, ValueFn>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
elem_ty: Type,
|
||||
|
@ -536,16 +537,12 @@ fn call_ndarray_eye_impl<'ctx>(
|
|||
ctx,
|
||||
ndarray,
|
||||
|generator, ctx, indices| {
|
||||
let row = ctx.build_gep_and_load(
|
||||
indices,
|
||||
&[llvm_usize.const_int(0, false)],
|
||||
None,
|
||||
).into_int_value();
|
||||
let col = ctx.build_gep_and_load(
|
||||
indices,
|
||||
&[llvm_usize.const_int(1, false)],
|
||||
None,
|
||||
).into_int_value();
|
||||
let (row, col) = unsafe {
|
||||
(
|
||||
indices.get_unchecked(ctx, generator, llvm_usize.const_int(0, false), None).into_int_value(),
|
||||
indices.get_unchecked(ctx, generator, llvm_usize.const_int(1, false), None).into_int_value(),
|
||||
)
|
||||
};
|
||||
|
||||
let col_with_offset = ctx.builder
|
||||
.build_int_add(
|
||||
|
@ -662,7 +659,7 @@ pub fn ndarray_elementwise_binop_impl<'ctx, G, ValueFn>(
|
|||
).unwrap()
|
||||
});
|
||||
|
||||
ndarray_fill_zip_map_flattened(
|
||||
ndarray_broadcast_fill_flattened(
|
||||
generator,
|
||||
ctx,
|
||||
elem_ty,
|
||||
|
|
|
@ -299,6 +299,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
Some("N".into()),
|
||||
None,
|
||||
);
|
||||
let size_t = primitives.0.usize();
|
||||
|
||||
let var_map: VarMap = vec![(num_ty.1, num_ty.0)].into_iter().collect();
|
||||
let exception_fields = vec![
|
||||
("__name__".into(), int32, true),
|
||||
|
@ -345,6 +347,11 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
.nth(1)
|
||||
.map(|(var_id, ty)| (*ty, *var_id))
|
||||
.unwrap();
|
||||
let ndarray_usized_ndims_tvar = primitives.1.get_fresh_const_generic_var(
|
||||
size_t,
|
||||
Some("ndarray_ndims".into()),
|
||||
None,
|
||||
);
|
||||
let ndarray_copy_ty = *ndarray_fields.get(&"copy".into()).unwrap();
|
||||
let ndarray_fill_ty = *ndarray_fields.get(&"fill".into()).unwrap();
|
||||
let ndarray_add_ty = *ndarray_fields.get(&"__add__".into()).unwrap();
|
||||
|
@ -699,7 +706,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
name: "ndarray.__iadd__".into(),
|
||||
simple_name: "__iadd__".into(),
|
||||
signature: ndarray_iadd_ty.0,
|
||||
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
|
||||
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id, ndarray_usized_ndims_tvar.1],
|
||||
instance_to_symbol: HashMap::default(),
|
||||
instance_to_stmt: HashMap::default(),
|
||||
resolver: None,
|
||||
|
|
|
@ -285,8 +285,11 @@ impl TopLevelComposer {
|
|||
]),
|
||||
});
|
||||
|
||||
let ndarray_usized_ndims_tvar = unifier.get_fresh_const_generic_var(size_t_ty, Some("ndarray_ndims".into()), None);
|
||||
let ndarray_unsized = subst_ndarray_tvars(&mut unifier, ndarray, Some(ndarray_usized_ndims_tvar.0), None);
|
||||
|
||||
unifier.unify(ndarray_copy_fun_ret_ty.0, ndarray).unwrap();
|
||||
unifier.unify(ndarray_binop_fun_other_ty.0, ndarray).unwrap();
|
||||
unifier.unify(ndarray_binop_fun_other_ty.0, ndarray_unsized).unwrap();
|
||||
unifier.unify(ndarray_binop_fun_ret_ty.0, ndarray).unwrap();
|
||||
|
||||
let ndarray_float = subst_ndarray_tvars(&mut unifier, ndarray, Some(float), None);
|
||||
|
|
|
@ -309,6 +309,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
|||
ndarray: ndarray_t,
|
||||
..
|
||||
} = *store;
|
||||
let size_t = store.usize();
|
||||
|
||||
/* int ======== */
|
||||
for t in [int32_t, int64_t, uint32_t, uint64_t] {
|
||||
|
@ -345,9 +346,11 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
|||
|
||||
/* ndarray ===== */
|
||||
let ndarray_float_t = make_ndarray_ty(unifier, store, Some(float_t), None);
|
||||
impl_basic_arithmetic(unifier, store, ndarray_t, &[ndarray_t], ndarray_t);
|
||||
impl_pow(unifier, store, ndarray_t, &[ndarray_t], ndarray_t);
|
||||
let ndarray_usized_ndims_tvar = unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None);
|
||||
let ndarray_unsized_t = make_ndarray_ty(unifier, store, None, Some(ndarray_usized_ndims_tvar.0));
|
||||
impl_basic_arithmetic(unifier, store, ndarray_t, &[ndarray_unsized_t], ndarray_t);
|
||||
impl_pow(unifier, store, ndarray_t, &[ndarray_unsized_t], ndarray_t);
|
||||
impl_div(unifier, store, ndarray_t, &[ndarray_t], ndarray_float_t);
|
||||
impl_floordiv(unifier, store, ndarray_t, &[ndarray_t], ndarray_t);
|
||||
impl_mod(unifier, store, ndarray_t, &[ndarray_t], ndarray_t);
|
||||
impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t], ndarray_t);
|
||||
impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t], ndarray_t);
|
||||
}
|
||||
|
|
|
@ -81,6 +81,20 @@ def test_ndarray_add():
|
|||
output_float64(y[1][0])
|
||||
output_float64(y[1][1])
|
||||
|
||||
# def test_ndarray_add_broadcast():
|
||||
# x = np_identity(2)
|
||||
# y: ndarray[float, 2] = x + np_ones([2])
|
||||
#
|
||||
# output_float64(x[0][0])
|
||||
# output_float64(x[0][1])
|
||||
# output_float64(x[1][0])
|
||||
# output_float64(x[1][1])
|
||||
#
|
||||
# output_float64(y[0][0])
|
||||
# output_float64(y[0][1])
|
||||
# output_float64(y[1][0])
|
||||
# output_float64(y[1][1])
|
||||
|
||||
def test_ndarray_iadd():
|
||||
x = np_identity(2)
|
||||
x += np_ones([2, 2])
|
||||
|
|
Loading…
Reference in New Issue