core: Add handling of ndarrays in gen_binop_expr
This commit is contained in:
parent
4887cd8007
commit
ddfd19d00c
|
@ -16,6 +16,7 @@ use crate::{
|
||||||
get_llvm_abi_type,
|
get_llvm_abi_type,
|
||||||
irrt::*,
|
irrt::*,
|
||||||
llvm_intrinsics::{call_expect, call_float_floor, call_float_pow, call_float_powi},
|
llvm_intrinsics::{call_expect, call_float_floor, call_float_pow, call_float_powi},
|
||||||
|
numpy,
|
||||||
stmt::{gen_raise, gen_var},
|
stmt::{gen_raise, gen_var},
|
||||||
CodeGenContext, CodeGenTask,
|
CodeGenContext, CodeGenTask,
|
||||||
},
|
},
|
||||||
|
@ -23,7 +24,7 @@ use crate::{
|
||||||
toplevel::{
|
toplevel::{
|
||||||
DefinitionId,
|
DefinitionId,
|
||||||
helper::PRIMITIVE_DEF_IDS,
|
helper::PRIMITIVE_DEF_IDS,
|
||||||
numpy::make_ndarray_ty,
|
numpy::{make_ndarray_ty, unpack_ndarray_tvars},
|
||||||
TopLevelDef,
|
TopLevelDef,
|
||||||
},
|
},
|
||||||
typecheck::{
|
typecheck::{
|
||||||
|
@ -1129,6 +1130,41 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
Some("f_pow_i")
|
Some("f_pow_i")
|
||||||
);
|
);
|
||||||
Ok(Some(res.into()))
|
Ok(Some(res.into()))
|
||||||
|
} else if ty1 == ty2 && matches!(&*ctx.unifier.get_ty(ty1), TypeEnum::TObj { obj_id, .. } if obj_id == &PRIMITIVE_DEF_IDS.ndarray) {
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let (ndarray_dtype, _) = unpack_ndarray_tvars(&mut ctx.unifier, ty1);
|
||||||
|
|
||||||
|
let left_val = NDArrayValue::from_ptr_val(
|
||||||
|
left_val.into_pointer_value(),
|
||||||
|
llvm_usize,
|
||||||
|
None
|
||||||
|
);
|
||||||
|
let right_val = NDArrayValue::from_ptr_val(
|
||||||
|
right_val.into_pointer_value(),
|
||||||
|
llvm_usize,
|
||||||
|
None
|
||||||
|
);
|
||||||
|
let res = numpy::ndarray_elementwise_binop_impl(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
ndarray_dtype,
|
||||||
|
if is_aug_assign { Some(left_val) } else { None },
|
||||||
|
left_val,
|
||||||
|
right_val,
|
||||||
|
|generator, ctx, elem_ty, (lhs, rhs)| {
|
||||||
|
gen_binop_expr_with_values(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
(&Some(elem_ty), lhs),
|
||||||
|
op,
|
||||||
|
(&Some(elem_ty), rhs),
|
||||||
|
ctx.current_loc,
|
||||||
|
is_aug_assign,
|
||||||
|
)?.unwrap().to_basic_value_enum(ctx, generator, elem_ty)
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(Some(res.as_ptr_value().into()))
|
||||||
} else {
|
} else {
|
||||||
let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap());
|
let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap());
|
||||||
let TypeEnum::TObj { fields, obj_id, .. } = left_ty_enum.as_ref() else {
|
let TypeEnum::TObj { fields, obj_id, .. } = left_ty_enum.as_ref() else {
|
||||||
|
|
|
@ -344,6 +344,42 @@ fn ndarray_fill_indexed<'ctx, G, ValueFn>(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Generates the LLVM IR for populating the entire `NDArray` using a lambda with the same-indexed
|
||||||
|
/// element from two other `NDArray` as its input.
|
||||||
|
fn ndarray_fill_zip_map_flattened<'ctx, G, ValueFn>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
elem_ty: Type,
|
||||||
|
res: NDArrayValue<'ctx>,
|
||||||
|
lhs: NDArrayValue<'ctx>,
|
||||||
|
rhs: NDArrayValue<'ctx>,
|
||||||
|
value_fn: ValueFn,
|
||||||
|
) -> Result<NDArrayValue<'ctx>, String>
|
||||||
|
where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, Type, (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>)) -> Result<BasicValueEnum<'ctx>, String>,
|
||||||
|
{
|
||||||
|
ndarray_fill_flattened(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
res,
|
||||||
|
|generator, ctx, idx| {
|
||||||
|
let elem = unsafe {
|
||||||
|
(
|
||||||
|
lhs.data().get_unchecked(ctx, generator, idx, None),
|
||||||
|
rhs.data().get_unchecked(ctx, generator, idx, None),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
debug_assert_eq!(elem.0.get_type(), elem.1.get_type());
|
||||||
|
|
||||||
|
value_fn(generator, ctx, elem_ty, elem)
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
|
||||||
/// LLVM-typed implementation for generating the implementation for `ndarray.zeros`.
|
/// LLVM-typed implementation for generating the implementation for `ndarray.zeros`.
|
||||||
///
|
///
|
||||||
/// * `elem_ty` - The element type of the `NDArray`.
|
/// * `elem_ty` - The element type of the `NDArray`.
|
||||||
|
|
|
@ -546,7 +546,7 @@ pub fn gen_for_callback<'ctx, 'a, G, I, InitFn, CondFn, BodyFn, UpdateFn>(
|
||||||
/// body(x);
|
/// body(x);
|
||||||
/// }
|
/// }
|
||||||
/// ```
|
/// ```
|
||||||
///
|
///
|
||||||
/// * `init_val` - The initial value of the loop variable. The type of this value will also be used
|
/// * `init_val` - The initial value of the loop variable. The type of this value will also be used
|
||||||
/// as the type of the loop variable.
|
/// as the type of the loop variable.
|
||||||
/// * `max_val` - A tuple containing the maximum value of the loop variable, and whether the maximum
|
/// * `max_val` - A tuple containing the maximum value of the loop variable, and whether the maximum
|
||||||
|
|
Loading…
Reference in New Issue