core: Add handling of ndarrays in gen_binop_expr
This commit is contained in:
parent
e3fe3f03fb
commit
4f236ea411
|
@ -16,6 +16,7 @@ use crate::{
|
||||||
get_llvm_abi_type,
|
get_llvm_abi_type,
|
||||||
irrt::*,
|
irrt::*,
|
||||||
llvm_intrinsics::{call_expect, call_float_floor, call_float_pow, call_float_powi},
|
llvm_intrinsics::{call_expect, call_float_floor, call_float_pow, call_float_powi},
|
||||||
|
numpy,
|
||||||
stmt::{gen_raise, gen_var},
|
stmt::{gen_raise, gen_var},
|
||||||
CodeGenContext, CodeGenTask,
|
CodeGenContext, CodeGenTask,
|
||||||
},
|
},
|
||||||
|
@ -23,7 +24,7 @@ use crate::{
|
||||||
toplevel::{
|
toplevel::{
|
||||||
DefinitionId,
|
DefinitionId,
|
||||||
helper::PRIMITIVE_DEF_IDS,
|
helper::PRIMITIVE_DEF_IDS,
|
||||||
numpy::make_ndarray_ty,
|
numpy::{make_ndarray_ty, unpack_ndarray_tvars},
|
||||||
TopLevelDef,
|
TopLevelDef,
|
||||||
},
|
},
|
||||||
typecheck::{
|
typecheck::{
|
||||||
|
@ -1128,6 +1129,41 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
Some("f_pow_i")
|
Some("f_pow_i")
|
||||||
);
|
);
|
||||||
Ok(Some(res.into()))
|
Ok(Some(res.into()))
|
||||||
|
} else if ty1 == ty2 && matches!(&*ctx.unifier.get_ty(ty1), TypeEnum::TObj { obj_id, .. } if obj_id == &PRIMITIVE_DEF_IDS.ndarray) {
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
let (ndarray_dtype, _) = unpack_ndarray_tvars(&mut ctx.unifier, ty1);
|
||||||
|
|
||||||
|
let left_val = NDArrayValue::from_ptr_val(
|
||||||
|
left_val.into_pointer_value(),
|
||||||
|
llvm_usize,
|
||||||
|
None
|
||||||
|
);
|
||||||
|
let right_val = NDArrayValue::from_ptr_val(
|
||||||
|
right_val.into_pointer_value(),
|
||||||
|
llvm_usize,
|
||||||
|
None
|
||||||
|
);
|
||||||
|
let res = numpy::ndarray_elementwise_binop_impl(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
ndarray_dtype,
|
||||||
|
if is_aug_assign { Some(left_val) } else { None },
|
||||||
|
left_val,
|
||||||
|
right_val,
|
||||||
|
|generator, ctx, elem_ty, (lhs, rhs)| {
|
||||||
|
gen_binop_expr_with_values(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
(&Some(elem_ty), lhs),
|
||||||
|
op,
|
||||||
|
(&Some(elem_ty), rhs),
|
||||||
|
ctx.current_loc,
|
||||||
|
is_aug_assign,
|
||||||
|
)?.unwrap().to_basic_value_enum(ctx, generator, elem_ty)
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(Some(res.as_ptr_value().into()))
|
||||||
} else {
|
} else {
|
||||||
let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap());
|
let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap());
|
||||||
let TypeEnum::TObj { fields, obj_id, .. } = left_ty_enum.as_ref() else {
|
let TypeEnum::TObj { fields, obj_id, .. } = left_ty_enum.as_ref() else {
|
||||||
|
|
|
@ -17,6 +17,7 @@ use crate::{
|
||||||
CodeGenContext,
|
CodeGenContext,
|
||||||
CodeGenerator,
|
CodeGenerator,
|
||||||
irrt::{
|
irrt::{
|
||||||
|
call_ndarray_calc_broadcast,
|
||||||
call_ndarray_calc_nd_indices,
|
call_ndarray_calc_nd_indices,
|
||||||
call_ndarray_calc_size,
|
call_ndarray_calc_size,
|
||||||
},
|
},
|
||||||
|
@ -343,6 +344,42 @@ fn ndarray_fill_indexed<'ctx, G, ValueFn>(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Generates the LLVM IR for populating the entire `NDArray` using a lambda with the same-indexed
|
||||||
|
/// element from two other `NDArray` as its input.
|
||||||
|
fn ndarray_fill_zip_map_flattened<'ctx, G, ValueFn>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
elem_ty: Type,
|
||||||
|
res: NDArrayValue<'ctx>,
|
||||||
|
lhs: NDArrayValue<'ctx>,
|
||||||
|
rhs: NDArrayValue<'ctx>,
|
||||||
|
value_fn: ValueFn,
|
||||||
|
) -> Result<NDArrayValue<'ctx>, String>
|
||||||
|
where
|
||||||
|
G: CodeGenerator + ?Sized,
|
||||||
|
ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, Type, (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>)) -> Result<BasicValueEnum<'ctx>, String>,
|
||||||
|
{
|
||||||
|
ndarray_fill_flattened(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
res,
|
||||||
|
|generator, ctx, idx| {
|
||||||
|
let elem = unsafe {
|
||||||
|
(
|
||||||
|
lhs.data().get_unchecked(ctx, generator, idx, None),
|
||||||
|
rhs.data().get_unchecked(ctx, generator, idx, None),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
debug_assert_eq!(elem.0.get_type(), elem.1.get_type());
|
||||||
|
|
||||||
|
value_fn(generator, ctx, elem_ty, elem)
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
|
||||||
/// LLVM-typed implementation for generating the implementation for `ndarray.zeros`.
|
/// LLVM-typed implementation for generating the implementation for `ndarray.zeros`.
|
||||||
///
|
///
|
||||||
/// * `elem_ty` - The element type of the `NDArray`.
|
/// * `elem_ty` - The element type of the `NDArray`.
|
||||||
|
@ -581,6 +618,63 @@ fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
Ok(ndarray)
|
Ok(ndarray)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// LLVM-typed implementation for computing elementwise binary operations.
|
||||||
|
///
|
||||||
|
/// * `elem_ty` - The element type of the `NDArray`.
|
||||||
|
/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be
|
||||||
|
/// written to a new `ndarray`.
|
||||||
|
/// * `value_fn` - Function mapping the two input elements into the result.
|
||||||
|
pub fn ndarray_elementwise_binop_impl<'ctx, G, ValueFn>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
elem_ty: Type,
|
||||||
|
res: Option<NDArrayValue<'ctx>>,
|
||||||
|
this: NDArrayValue<'ctx>,
|
||||||
|
other: NDArrayValue<'ctx>,
|
||||||
|
value_fn: ValueFn,
|
||||||
|
) -> Result<NDArrayValue<'ctx>, String>
|
||||||
|
where
|
||||||
|
G: CodeGenerator,
|
||||||
|
ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, Type, (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>)) -> Result<BasicValueEnum<'ctx>, String>,
|
||||||
|
{
|
||||||
|
let ndarray_dims = call_ndarray_calc_broadcast(generator, ctx, this, other);
|
||||||
|
let ndarray = res.unwrap_or_else(|| {
|
||||||
|
create_ndarray_dyn_shape(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
elem_ty,
|
||||||
|
&ndarray_dims,
|
||||||
|
|_, _, v| {
|
||||||
|
Ok(v.0)
|
||||||
|
},
|
||||||
|
|_, ctx, v, idx| {
|
||||||
|
unsafe {
|
||||||
|
let data_ptr = ctx.builder.build_in_bounds_gep(v.1, &[idx], "")
|
||||||
|
.map_err(|e| e.to_string())?;
|
||||||
|
|
||||||
|
ctx.builder.build_load(data_ptr, "")
|
||||||
|
.map(BasicValueEnum::into_int_value)
|
||||||
|
.map_err(|e| e.to_string())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
).unwrap()
|
||||||
|
});
|
||||||
|
|
||||||
|
ndarray_fill_zip_map_flattened(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
elem_ty,
|
||||||
|
ndarray,
|
||||||
|
this,
|
||||||
|
other,
|
||||||
|
|generator, ctx, elem_ty, elems| {
|
||||||
|
value_fn(generator, ctx, elem_ty, elems)
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(ndarray)
|
||||||
|
}
|
||||||
|
|
||||||
/// Generates LLVM IR for `ndarray.empty`.
|
/// Generates LLVM IR for `ndarray.empty`.
|
||||||
pub fn gen_ndarray_empty<'ctx>(
|
pub fn gen_ndarray_empty<'ctx>(
|
||||||
context: &mut CodeGenContext<'ctx, '_>,
|
context: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
|
Loading…
Reference in New Issue