forked from M-Labs/nac3
core/ndstrides: implement binop
This commit is contained in:
parent
cb8cea4286
commit
f5698a9eed
|
@ -12,6 +12,7 @@ use crate::{
|
||||||
call_int_umin, call_memcpy_generic,
|
call_int_umin, call_memcpy_generic,
|
||||||
},
|
},
|
||||||
need_sret, numpy,
|
need_sret, numpy,
|
||||||
|
object::ndarray::{NDArrayOut, ScalarOrNDArray},
|
||||||
stmt::{
|
stmt::{
|
||||||
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
|
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
|
||||||
gen_var,
|
gen_var,
|
||||||
|
@ -28,7 +29,10 @@ use crate::{
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
attributes::{Attribute, AttributeLoc},
|
attributes::{Attribute, AttributeLoc},
|
||||||
types::{AnyType, BasicType, BasicTypeEnum},
|
types::{AnyType, BasicType, BasicTypeEnum},
|
||||||
values::{BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue, StructValue},
|
values::{
|
||||||
|
BasicValue, BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue,
|
||||||
|
StructValue,
|
||||||
|
},
|
||||||
AddressSpace, IntPredicate, OptimizationLevel,
|
AddressSpace, IntPredicate, OptimizationLevel,
|
||||||
};
|
};
|
||||||
use itertools::{chain, izip, Either, Itertools};
|
use itertools::{chain, izip, Either, Itertools};
|
||||||
|
@ -1544,99 +1548,71 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
} else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
} else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||||
|| ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
|| ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||||
{
|
{
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let left =
|
||||||
|
ScalarOrNDArray::split_object(generator, ctx, AnyObject { ty: ty1, value: left_val });
|
||||||
|
let right =
|
||||||
|
ScalarOrNDArray::split_object(generator, ctx, AnyObject { ty: ty2, value: right_val });
|
||||||
|
|
||||||
let is_ndarray1 = ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
// Inhomogeneous binary operations are not supported.
|
||||||
let is_ndarray2 = ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
assert!(ctx.unifier.unioned(left.get_dtype(), right.get_dtype()));
|
||||||
|
|
||||||
if is_ndarray1 && is_ndarray2 {
|
let common_dtype = left.get_dtype();
|
||||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1);
|
|
||||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty2);
|
|
||||||
|
|
||||||
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
let out = match op.variant {
|
||||||
|
BinopVariant::Normal => NDArrayOut::NewNDArray { dtype: common_dtype },
|
||||||
let left_val =
|
BinopVariant::AugAssign => {
|
||||||
NDArrayValue::from_ptr_val(left_val.into_pointer_value(), llvm_usize, None);
|
// If this is an augmented assignment.
|
||||||
let right_val =
|
// `left` has to be an ndarray. If it were a scalar then NAC3 simply doesn't support it.
|
||||||
NDArrayValue::from_ptr_val(right_val.into_pointer_value(), llvm_usize, None);
|
if let ScalarOrNDArray::NDArray(out_ndarray) = left {
|
||||||
|
NDArrayOut::WriteToNDArray { ndarray: out_ndarray }
|
||||||
let res = if op.base == Operator::MatMult {
|
|
||||||
// MatMult is the only binop which is not an elementwise op
|
|
||||||
numpy::ndarray_matmul_2d(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
ndarray_dtype1,
|
|
||||||
match op.variant {
|
|
||||||
BinopVariant::Normal => None,
|
|
||||||
BinopVariant::AugAssign => Some(left_val),
|
|
||||||
},
|
|
||||||
left_val,
|
|
||||||
right_val,
|
|
||||||
)?
|
|
||||||
} else {
|
} else {
|
||||||
numpy::ndarray_elementwise_binop_impl(
|
panic!("left must be an ndarray")
|
||||||
generator,
|
}
|
||||||
ctx,
|
}
|
||||||
ndarray_dtype1,
|
|
||||||
match op.variant {
|
|
||||||
BinopVariant::Normal => None,
|
|
||||||
BinopVariant::AugAssign => Some(left_val),
|
|
||||||
},
|
|
||||||
(left_val.as_base_value().into(), false),
|
|
||||||
(right_val.as_base_value().into(), false),
|
|
||||||
|generator, ctx, (lhs, rhs)| {
|
|
||||||
gen_binop_expr_with_values(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
(&Some(ndarray_dtype1), lhs),
|
|
||||||
op,
|
|
||||||
(&Some(ndarray_dtype2), rhs),
|
|
||||||
ctx.current_loc,
|
|
||||||
)?
|
|
||||||
.unwrap()
|
|
||||||
.to_basic_value_enum(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
ndarray_dtype1,
|
|
||||||
)
|
|
||||||
},
|
|
||||||
)?
|
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(Some(res.as_base_value().into()))
|
if op.base == Operator::MatMult {
|
||||||
|
// Handle matrix multiplication.
|
||||||
|
todo!()
|
||||||
} else {
|
} else {
|
||||||
let (ndarray_dtype, _) =
|
// For other operations, they are all elementwise operations.
|
||||||
unpack_ndarray_var_tys(&mut ctx.unifier, if is_ndarray1 { ty1 } else { ty2 });
|
|
||||||
let ndarray_val = NDArrayValue::from_ptr_val(
|
// There are only three cases:
|
||||||
if is_ndarray1 { left_val } else { right_val }.into_pointer_value(),
|
// - LHS is a scalar, RHS is an ndarray.
|
||||||
llvm_usize,
|
// - LHS is an ndarray, RHS is a scalar.
|
||||||
None,
|
// - LHS is an ndarray, RHS is an ndarray.
|
||||||
);
|
//
|
||||||
let res = numpy::ndarray_elementwise_binop_impl(
|
// For all cases, the scalar operand is promoted to an ndarray,
|
||||||
|
// the two are then broadcasted, and starmapped through.
|
||||||
|
|
||||||
|
let left = left.to_ndarray(generator, ctx);
|
||||||
|
let right = right.to_ndarray(generator, ctx);
|
||||||
|
|
||||||
|
let result = NDArrayObject::broadcast_starmap(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ndarray_dtype,
|
&[left, right],
|
||||||
match op.variant {
|
out,
|
||||||
BinopVariant::Normal => None,
|
|generator, ctx, scalars| {
|
||||||
BinopVariant::AugAssign => Some(ndarray_val),
|
let left_value = scalars[0];
|
||||||
},
|
let right_value = scalars[1];
|
||||||
(left_val, !is_ndarray1),
|
|
||||||
(right_val, !is_ndarray2),
|
let result = gen_binop_expr_with_values(
|
||||||
|generator, ctx, (lhs, rhs)| {
|
|
||||||
gen_binop_expr_with_values(
|
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
(&Some(ndarray_dtype), lhs),
|
(&Some(left.dtype), left_value),
|
||||||
op,
|
op,
|
||||||
(&Some(ndarray_dtype), rhs),
|
(&Some(right.dtype), right_value),
|
||||||
ctx.current_loc,
|
ctx.current_loc,
|
||||||
)?
|
)?
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.to_basic_value_enum(ctx, generator, ndarray_dtype)
|
.to_basic_value_enum(ctx, generator, common_dtype)?;
|
||||||
},
|
|
||||||
)?;
|
|
||||||
|
|
||||||
Ok(Some(res.as_base_value().into()))
|
Ok(result)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
Ok(Some(ValueEnum::Dynamic(result.instance.value.as_basic_value_enum())))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap());
|
let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap());
|
||||||
|
|
Loading…
Reference in New Issue