core: Add handling of ndarrays in gen_binop_expr

This commit is contained in:
David Mak 2024-03-13 11:16:23 +08:00
parent 4887cd8007
commit ddfd19d00c
3 changed files with 74 additions and 2 deletions

View File

@ -16,6 +16,7 @@ use crate::{
get_llvm_abi_type,
irrt::*,
llvm_intrinsics::{call_expect, call_float_floor, call_float_pow, call_float_powi},
numpy,
stmt::{gen_raise, gen_var},
CodeGenContext, CodeGenTask,
},
@ -23,7 +24,7 @@ use crate::{
toplevel::{
DefinitionId,
helper::PRIMITIVE_DEF_IDS,
numpy::make_ndarray_ty,
numpy::{make_ndarray_ty, unpack_ndarray_tvars},
TopLevelDef,
},
typecheck::{
@ -1129,6 +1130,41 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
Some("f_pow_i")
);
Ok(Some(res.into()))
} else if ty1 == ty2 && matches!(&*ctx.unifier.get_ty(ty1), TypeEnum::TObj { obj_id, .. } if obj_id == &PRIMITIVE_DEF_IDS.ndarray) {
let llvm_usize = generator.get_size_type(ctx.ctx);
let (ndarray_dtype, _) = unpack_ndarray_tvars(&mut ctx.unifier, ty1);
let left_val = NDArrayValue::from_ptr_val(
left_val.into_pointer_value(),
llvm_usize,
None
);
let right_val = NDArrayValue::from_ptr_val(
right_val.into_pointer_value(),
llvm_usize,
None
);
let res = numpy::ndarray_elementwise_binop_impl(
generator,
ctx,
ndarray_dtype,
if is_aug_assign { Some(left_val) } else { None },
left_val,
right_val,
|generator, ctx, elem_ty, (lhs, rhs)| {
gen_binop_expr_with_values(
generator,
ctx,
(&Some(elem_ty), lhs),
op,
(&Some(elem_ty), rhs),
ctx.current_loc,
is_aug_assign,
)?.unwrap().to_basic_value_enum(ctx, generator, elem_ty)
},
)?;
Ok(Some(res.as_ptr_value().into()))
} else {
let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap());
let TypeEnum::TObj { fields, obj_id, .. } = left_ty_enum.as_ref() else {

View File

@ -344,6 +344,42 @@ fn ndarray_fill_indexed<'ctx, G, ValueFn>(
)
}
/// Generates the LLVM IR for populating the entire `NDArray` using a lambda with the same-indexed
/// element from two other `NDArray` as its input.
fn ndarray_fill_zip_map_flattened<'ctx, G, ValueFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
res: NDArrayValue<'ctx>,
lhs: NDArrayValue<'ctx>,
rhs: NDArrayValue<'ctx>,
value_fn: ValueFn,
) -> Result<NDArrayValue<'ctx>, String>
where
G: CodeGenerator + ?Sized,
ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, Type, (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>)) -> Result<BasicValueEnum<'ctx>, String>,
{
ndarray_fill_flattened(
generator,
ctx,
res,
|generator, ctx, idx| {
let elem = unsafe {
(
lhs.data().get_unchecked(ctx, generator, idx, None),
rhs.data().get_unchecked(ctx, generator, idx, None),
)
};
debug_assert_eq!(elem.0.get_type(), elem.1.get_type());
value_fn(generator, ctx, elem_ty, elem)
},
)?;
Ok(res)
}
/// LLVM-typed implementation for generating the implementation for `ndarray.zeros`.
///
/// * `elem_ty` - The element type of the `NDArray`.

View File

@ -546,7 +546,7 @@ pub fn gen_for_callback<'ctx, 'a, G, I, InitFn, CondFn, BodyFn, UpdateFn>(
/// body(x);
/// }
/// ```
///
///
/// * `init_val` - The initial value of the loop variable. The type of this value will also be used
/// as the type of the loop variable.
/// * `max_val` - A tuple containing the maximum value of the loop variable, and whether the maximum