forked from M-Labs/nac3
[core] codegen: Reimplement ndarray binop
Based on 9e40c834
: core/ndstrides: implement binop
This commit is contained in:
parent
6cbba8fdde
commit
59f19e29df
@ -34,14 +34,19 @@ use super::{
|
||||
},
|
||||
types::{ndarray::NDArrayType, ListType},
|
||||
values::{
|
||||
ndarray::RustNDIndex, ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, RangeValue,
|
||||
ndarray::{NDArrayOut, RustNDIndex, ScalarOrNDArray},
|
||||
ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, RangeValue,
|
||||
UntypedArrayLikeAccessor,
|
||||
},
|
||||
CodeGenContext, CodeGenTask, CodeGenerator,
|
||||
};
|
||||
use crate::{
|
||||
symbol_resolver::{SymbolValue, ValueEnum},
|
||||
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef},
|
||||
toplevel::{
|
||||
helper::{arraylike_flatten_element_type, PrimDef},
|
||||
numpy::unpack_ndarray_var_tys,
|
||||
DefinitionId, TopLevelDef,
|
||||
},
|
||||
typecheck::{
|
||||
magic_methods::{Binop, BinopVariant, HasOpInfo},
|
||||
typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap},
|
||||
@ -1526,98 +1531,90 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
} else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||
|| ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||
{
|
||||
let is_ndarray1 = ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||
let is_ndarray2 = ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||
let left = ScalarOrNDArray::from_value(generator, ctx, (ty1, left_val));
|
||||
let right = ScalarOrNDArray::from_value(generator, ctx, (ty2, right_val));
|
||||
|
||||
if is_ndarray1 && is_ndarray2 {
|
||||
if op.base == Operator::MatMult {
|
||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1);
|
||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty2);
|
||||
|
||||
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||
let left = left.to_ndarray(generator, ctx);
|
||||
let right = right.to_ndarray(generator, ctx);
|
||||
|
||||
let left_val = NDArrayType::from_unifier_type(generator, ctx, ty1)
|
||||
.map_value(left_val.into_pointer_value(), None);
|
||||
let right_val = NDArrayType::from_unifier_type(generator, ctx, ty2)
|
||||
.map_value(right_val.into_pointer_value(), None);
|
||||
|
||||
let res = if op.base == Operator::MatMult {
|
||||
// MatMult is the only binop which is not an elementwise op
|
||||
numpy::ndarray_matmul_2d(
|
||||
generator,
|
||||
ctx,
|
||||
ndarray_dtype1,
|
||||
match op.variant {
|
||||
BinopVariant::Normal => None,
|
||||
BinopVariant::AugAssign => Some(left_val),
|
||||
},
|
||||
left_val,
|
||||
right_val,
|
||||
)?
|
||||
} else {
|
||||
numpy::ndarray_elementwise_binop_impl(
|
||||
generator,
|
||||
ctx,
|
||||
ndarray_dtype1,
|
||||
match op.variant {
|
||||
BinopVariant::Normal => None,
|
||||
BinopVariant::AugAssign => Some(left_val),
|
||||
},
|
||||
(ty1, left_val.as_base_value().into(), false),
|
||||
(ty2, right_val.as_base_value().into(), false),
|
||||
|generator, ctx, (lhs, rhs)| {
|
||||
gen_binop_expr_with_values(
|
||||
generator,
|
||||
ctx,
|
||||
(&Some(ndarray_dtype1), lhs),
|
||||
op,
|
||||
(&Some(ndarray_dtype2), rhs),
|
||||
ctx.current_loc,
|
||||
)?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(
|
||||
ctx,
|
||||
generator,
|
||||
ndarray_dtype1,
|
||||
)
|
||||
},
|
||||
)?
|
||||
};
|
||||
|
||||
Ok(Some(res.as_base_value().into()))
|
||||
} else {
|
||||
let (ndarray_dtype, _) =
|
||||
unpack_ndarray_var_tys(&mut ctx.unifier, if is_ndarray1 { ty1 } else { ty2 });
|
||||
let ndarray_val =
|
||||
NDArrayType::from_unifier_type(generator, ctx, if is_ndarray1 { ty1 } else { ty2 })
|
||||
.map_value(
|
||||
if is_ndarray1 { left_val } else { right_val }.into_pointer_value(),
|
||||
None,
|
||||
);
|
||||
let res = numpy::ndarray_elementwise_binop_impl(
|
||||
// MatMult is the only binop which is not an elementwise op
|
||||
let result = numpy::ndarray_matmul_2d(
|
||||
generator,
|
||||
ctx,
|
||||
ndarray_dtype,
|
||||
ndarray_dtype1,
|
||||
match op.variant {
|
||||
BinopVariant::Normal => None,
|
||||
BinopVariant::AugAssign => Some(ndarray_val),
|
||||
},
|
||||
(ty1, left_val, !is_ndarray1),
|
||||
(ty2, right_val, !is_ndarray2),
|
||||
|generator, ctx, (lhs, rhs)| {
|
||||
gen_binop_expr_with_values(
|
||||
generator,
|
||||
ctx,
|
||||
(&Some(ndarray_dtype), lhs),
|
||||
op,
|
||||
(&Some(ndarray_dtype), rhs),
|
||||
ctx.current_loc,
|
||||
)?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, ndarray_dtype)
|
||||
BinopVariant::AugAssign => Some(left),
|
||||
},
|
||||
left,
|
||||
right,
|
||||
)?;
|
||||
|
||||
Ok(Some(res.as_base_value().into()))
|
||||
Ok(Some(result.as_base_value().into()))
|
||||
} else {
|
||||
// For other operations, they are all elementwise operations.
|
||||
|
||||
// There are only three cases:
|
||||
// - LHS is a scalar, RHS is an ndarray.
|
||||
// - LHS is an ndarray, RHS is a scalar.
|
||||
// - LHS is an ndarray, RHS is an ndarray.
|
||||
//
|
||||
// For all cases, the scalar operand is promoted to an ndarray,
|
||||
// the two are then broadcasted, and starmapped through.
|
||||
|
||||
let ty1_dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty1);
|
||||
let ty2_dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty2);
|
||||
|
||||
// Inhomogeneous binary operations are not supported.
|
||||
assert!(ctx.unifier.unioned(ty1_dtype, ty2_dtype));
|
||||
|
||||
let common_dtype = ty1_dtype;
|
||||
let llvm_common_dtype = left.get_dtype();
|
||||
|
||||
let out = match op.variant {
|
||||
BinopVariant::Normal => NDArrayOut::NewNDArray { dtype: llvm_common_dtype },
|
||||
BinopVariant::AugAssign => {
|
||||
// If this is an augmented assignment.
|
||||
// `left` has to be an ndarray. If it were a scalar then NAC3 simply doesn't support it.
|
||||
if let ScalarOrNDArray::NDArray(out_ndarray) = left {
|
||||
NDArrayOut::WriteToNDArray { ndarray: out_ndarray }
|
||||
} else {
|
||||
panic!("left must be an ndarray")
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let left = left.to_ndarray(generator, ctx);
|
||||
let right = right.to_ndarray(generator, ctx);
|
||||
|
||||
let result = NDArrayType::new_broadcast(
|
||||
generator,
|
||||
ctx.ctx,
|
||||
llvm_common_dtype,
|
||||
&[left.get_type(), right.get_type()],
|
||||
)
|
||||
.broadcast_starmap(generator, ctx, &[left, right], out, |generator, ctx, scalars| {
|
||||
let left_value = scalars[0];
|
||||
let right_value = scalars[1];
|
||||
|
||||
let result = gen_binop_expr_with_values(
|
||||
generator,
|
||||
ctx,
|
||||
(&Some(ty1_dtype), left_value),
|
||||
op,
|
||||
(&Some(ty2_dtype), right_value),
|
||||
ctx.current_loc,
|
||||
)?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, common_dtype)?;
|
||||
|
||||
Ok(result)
|
||||
})
|
||||
.unwrap();
|
||||
Ok(Some(result.as_base_value().into()))
|
||||
}
|
||||
} else {
|
||||
let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap());
|
||||
|
Loading…
Reference in New Issue
Block a user