core: Add handling of ndarrays in gen_binop_expr
This commit is contained in:
parent
e3fe3f03fb
commit
4f236ea411
|
@ -16,6 +16,7 @@ use crate::{
|
|||
get_llvm_abi_type,
|
||||
irrt::*,
|
||||
llvm_intrinsics::{call_expect, call_float_floor, call_float_pow, call_float_powi},
|
||||
numpy,
|
||||
stmt::{gen_raise, gen_var},
|
||||
CodeGenContext, CodeGenTask,
|
||||
},
|
||||
|
@ -23,7 +24,7 @@ use crate::{
|
|||
toplevel::{
|
||||
DefinitionId,
|
||||
helper::PRIMITIVE_DEF_IDS,
|
||||
numpy::make_ndarray_ty,
|
||||
numpy::{make_ndarray_ty, unpack_ndarray_tvars},
|
||||
TopLevelDef,
|
||||
},
|
||||
typecheck::{
|
||||
|
@ -1128,6 +1129,41 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||
Some("f_pow_i")
|
||||
);
|
||||
Ok(Some(res.into()))
|
||||
} else if ty1 == ty2 && matches!(&*ctx.unifier.get_ty(ty1), TypeEnum::TObj { obj_id, .. } if obj_id == &PRIMITIVE_DEF_IDS.ndarray) {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let (ndarray_dtype, _) = unpack_ndarray_tvars(&mut ctx.unifier, ty1);
|
||||
|
||||
let left_val = NDArrayValue::from_ptr_val(
|
||||
left_val.into_pointer_value(),
|
||||
llvm_usize,
|
||||
None
|
||||
);
|
||||
let right_val = NDArrayValue::from_ptr_val(
|
||||
right_val.into_pointer_value(),
|
||||
llvm_usize,
|
||||
None
|
||||
);
|
||||
let res = numpy::ndarray_elementwise_binop_impl(
|
||||
generator,
|
||||
ctx,
|
||||
ndarray_dtype,
|
||||
if is_aug_assign { Some(left_val) } else { None },
|
||||
left_val,
|
||||
right_val,
|
||||
|generator, ctx, elem_ty, (lhs, rhs)| {
|
||||
gen_binop_expr_with_values(
|
||||
generator,
|
||||
ctx,
|
||||
(&Some(elem_ty), lhs),
|
||||
op,
|
||||
(&Some(elem_ty), rhs),
|
||||
ctx.current_loc,
|
||||
is_aug_assign,
|
||||
)?.unwrap().to_basic_value_enum(ctx, generator, elem_ty)
|
||||
},
|
||||
)?;
|
||||
|
||||
Ok(Some(res.as_ptr_value().into()))
|
||||
} else {
|
||||
let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap());
|
||||
let TypeEnum::TObj { fields, obj_id, .. } = left_ty_enum.as_ref() else {
|
||||
|
|
|
@ -17,6 +17,7 @@ use crate::{
|
|||
CodeGenContext,
|
||||
CodeGenerator,
|
||||
irrt::{
|
||||
call_ndarray_calc_broadcast,
|
||||
call_ndarray_calc_nd_indices,
|
||||
call_ndarray_calc_size,
|
||||
},
|
||||
|
@ -343,6 +344,42 @@ fn ndarray_fill_indexed<'ctx, G, ValueFn>(
|
|||
)
|
||||
}
|
||||
|
||||
/// Generates the LLVM IR for populating the entire `NDArray` using a lambda with the same-indexed
|
||||
/// element from two other `NDArray` as its input.
|
||||
fn ndarray_fill_zip_map_flattened<'ctx, G, ValueFn>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
elem_ty: Type,
|
||||
res: NDArrayValue<'ctx>,
|
||||
lhs: NDArrayValue<'ctx>,
|
||||
rhs: NDArrayValue<'ctx>,
|
||||
value_fn: ValueFn,
|
||||
) -> Result<NDArrayValue<'ctx>, String>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, Type, (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>)) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
{
|
||||
ndarray_fill_flattened(
|
||||
generator,
|
||||
ctx,
|
||||
res,
|
||||
|generator, ctx, idx| {
|
||||
let elem = unsafe {
|
||||
(
|
||||
lhs.data().get_unchecked(ctx, generator, idx, None),
|
||||
rhs.data().get_unchecked(ctx, generator, idx, None),
|
||||
)
|
||||
};
|
||||
|
||||
debug_assert_eq!(elem.0.get_type(), elem.1.get_type());
|
||||
|
||||
value_fn(generator, ctx, elem_ty, elem)
|
||||
},
|
||||
)?;
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
/// LLVM-typed implementation for generating the implementation for `ndarray.zeros`.
|
||||
///
|
||||
/// * `elem_ty` - The element type of the `NDArray`.
|
||||
|
@ -581,6 +618,63 @@ fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|||
Ok(ndarray)
|
||||
}
|
||||
|
||||
/// LLVM-typed implementation for computing elementwise binary operations.
|
||||
///
|
||||
/// * `elem_ty` - The element type of the `NDArray`.
|
||||
/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be
|
||||
/// written to a new `ndarray`.
|
||||
/// * `value_fn` - Function mapping the two input elements into the result.
|
||||
pub fn ndarray_elementwise_binop_impl<'ctx, G, ValueFn>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
elem_ty: Type,
|
||||
res: Option<NDArrayValue<'ctx>>,
|
||||
this: NDArrayValue<'ctx>,
|
||||
other: NDArrayValue<'ctx>,
|
||||
value_fn: ValueFn,
|
||||
) -> Result<NDArrayValue<'ctx>, String>
|
||||
where
|
||||
G: CodeGenerator,
|
||||
ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, Type, (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>)) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
{
|
||||
let ndarray_dims = call_ndarray_calc_broadcast(generator, ctx, this, other);
|
||||
let ndarray = res.unwrap_or_else(|| {
|
||||
create_ndarray_dyn_shape(
|
||||
generator,
|
||||
ctx,
|
||||
elem_ty,
|
||||
&ndarray_dims,
|
||||
|_, _, v| {
|
||||
Ok(v.0)
|
||||
},
|
||||
|_, ctx, v, idx| {
|
||||
unsafe {
|
||||
let data_ptr = ctx.builder.build_in_bounds_gep(v.1, &[idx], "")
|
||||
.map_err(|e| e.to_string())?;
|
||||
|
||||
ctx.builder.build_load(data_ptr, "")
|
||||
.map(BasicValueEnum::into_int_value)
|
||||
.map_err(|e| e.to_string())
|
||||
}
|
||||
},
|
||||
).unwrap()
|
||||
});
|
||||
|
||||
ndarray_fill_zip_map_flattened(
|
||||
generator,
|
||||
ctx,
|
||||
elem_ty,
|
||||
ndarray,
|
||||
this,
|
||||
other,
|
||||
|generator, ctx, elem_ty, elems| {
|
||||
value_fn(generator, ctx, elem_ty, elems)
|
||||
},
|
||||
)?;
|
||||
|
||||
Ok(ndarray)
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for `ndarray.empty`.
|
||||
pub fn gen_ndarray_empty<'ctx>(
|
||||
context: &mut CodeGenContext<'ctx, '_>,
|
||||
|
|
|
@ -546,7 +546,7 @@ pub fn gen_for_callback<'ctx, 'a, G, I, InitFn, CondFn, BodyFn, UpdateFn>(
|
|||
/// body(x);
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
///
|
||||
/// * `init_val` - The initial value of the loop variable. The type of this value will also be used
|
||||
/// as the type of the loop variable.
|
||||
/// * `max_val` - A tuple containing the maximum value of the loop variable, and whether the maximum
|
||||
|
|
Loading…
Reference in New Issue