core: Add handling of ndarrays in gen_binop_expr

This commit is contained in:
David Mak 2024-03-13 11:16:23 +08:00
parent e3fe3f03fb
commit 4f236ea411
3 changed files with 132 additions and 2 deletions

View File

@ -16,6 +16,7 @@ use crate::{
get_llvm_abi_type,
irrt::*,
llvm_intrinsics::{call_expect, call_float_floor, call_float_pow, call_float_powi},
numpy,
stmt::{gen_raise, gen_var},
CodeGenContext, CodeGenTask,
},
@ -23,7 +24,7 @@ use crate::{
toplevel::{
DefinitionId,
helper::PRIMITIVE_DEF_IDS,
numpy::make_ndarray_ty,
numpy::{make_ndarray_ty, unpack_ndarray_tvars},
TopLevelDef,
},
typecheck::{
@ -1128,6 +1129,41 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
Some("f_pow_i")
);
Ok(Some(res.into()))
} else if ty1 == ty2 && matches!(&*ctx.unifier.get_ty(ty1), TypeEnum::TObj { obj_id, .. } if obj_id == &PRIMITIVE_DEF_IDS.ndarray) {
let llvm_usize = generator.get_size_type(ctx.ctx);
let (ndarray_dtype, _) = unpack_ndarray_tvars(&mut ctx.unifier, ty1);
let left_val = NDArrayValue::from_ptr_val(
left_val.into_pointer_value(),
llvm_usize,
None
);
let right_val = NDArrayValue::from_ptr_val(
right_val.into_pointer_value(),
llvm_usize,
None
);
let res = numpy::ndarray_elementwise_binop_impl(
generator,
ctx,
ndarray_dtype,
if is_aug_assign { Some(left_val) } else { None },
left_val,
right_val,
|generator, ctx, elem_ty, (lhs, rhs)| {
gen_binop_expr_with_values(
generator,
ctx,
(&Some(elem_ty), lhs),
op,
(&Some(elem_ty), rhs),
ctx.current_loc,
is_aug_assign,
)?.unwrap().to_basic_value_enum(ctx, generator, elem_ty)
},
)?;
Ok(Some(res.as_ptr_value().into()))
} else {
let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap());
let TypeEnum::TObj { fields, obj_id, .. } = left_ty_enum.as_ref() else {

View File

@ -17,6 +17,7 @@ use crate::{
CodeGenContext,
CodeGenerator,
irrt::{
call_ndarray_calc_broadcast,
call_ndarray_calc_nd_indices,
call_ndarray_calc_size,
},
@ -343,6 +344,42 @@ fn ndarray_fill_indexed<'ctx, G, ValueFn>(
)
}
/// Generates the LLVM IR for populating the entire `NDArray` using a lambda with the same-indexed
/// element from two other `NDArray` as its input.
fn ndarray_fill_zip_map_flattened<'ctx, G, ValueFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
res: NDArrayValue<'ctx>,
lhs: NDArrayValue<'ctx>,
rhs: NDArrayValue<'ctx>,
value_fn: ValueFn,
) -> Result<NDArrayValue<'ctx>, String>
where
G: CodeGenerator + ?Sized,
ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, Type, (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>)) -> Result<BasicValueEnum<'ctx>, String>,
{
ndarray_fill_flattened(
generator,
ctx,
res,
|generator, ctx, idx| {
let elem = unsafe {
(
lhs.data().get_unchecked(ctx, generator, idx, None),
rhs.data().get_unchecked(ctx, generator, idx, None),
)
};
debug_assert_eq!(elem.0.get_type(), elem.1.get_type());
value_fn(generator, ctx, elem_ty, elem)
},
)?;
Ok(res)
}
/// LLVM-typed implementation for generating the implementation for `ndarray.zeros`.
///
/// * `elem_ty` - The element type of the `NDArray`.
@ -581,6 +618,63 @@ fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>(
Ok(ndarray)
}
/// LLVM-typed implementation for computing elementwise binary operations.
///
/// * `elem_ty` - The element type of the `NDArray`.
/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be
/// written to a new `ndarray`.
/// * `value_fn` - Function mapping the two input elements into the result.
pub fn ndarray_elementwise_binop_impl<'ctx, G, ValueFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
res: Option<NDArrayValue<'ctx>>,
this: NDArrayValue<'ctx>,
other: NDArrayValue<'ctx>,
value_fn: ValueFn,
) -> Result<NDArrayValue<'ctx>, String>
where
G: CodeGenerator,
ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, Type, (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>)) -> Result<BasicValueEnum<'ctx>, String>,
{
let ndarray_dims = call_ndarray_calc_broadcast(generator, ctx, this, other);
let ndarray = res.unwrap_or_else(|| {
create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&ndarray_dims,
|_, _, v| {
Ok(v.0)
},
|_, ctx, v, idx| {
unsafe {
let data_ptr = ctx.builder.build_in_bounds_gep(v.1, &[idx], "")
.map_err(|e| e.to_string())?;
ctx.builder.build_load(data_ptr, "")
.map(BasicValueEnum::into_int_value)
.map_err(|e| e.to_string())
}
},
).unwrap()
});
ndarray_fill_zip_map_flattened(
generator,
ctx,
elem_ty,
ndarray,
this,
other,
|generator, ctx, elem_ty, elems| {
value_fn(generator, ctx, elem_ty, elems)
},
)?;
Ok(ndarray)
}
/// Generates LLVM IR for `ndarray.empty`.
pub fn gen_ndarray_empty<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,

View File

@ -546,7 +546,7 @@ pub fn gen_for_callback<'ctx, 'a, G, I, InitFn, CondFn, BodyFn, UpdateFn>(
/// body(x);
/// }
/// ```
///
///
/// * `init_val` - The initial value of the loop variable. The type of this value will also be used
/// as the type of the loop variable.
/// * `max_val` - A tuple containing the maximum value of the loop variable, and whether the maximum