core/ndstrides: implement binop

This commit is contained in:
lyken 2024-08-25 00:04:10 +08:00 committed by David Mak
parent fbfc0b293a
commit 9e40c83490

View File

@ -8,7 +8,10 @@ use std::{
use inkwell::{ use inkwell::{
attributes::{Attribute, AttributeLoc}, attributes::{Attribute, AttributeLoc},
types::{AnyType, BasicType, BasicTypeEnum}, types::{AnyType, BasicType, BasicTypeEnum},
values::{BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue, StructValue}, values::{
BasicValue, BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue,
StructValue,
},
AddressSpace, IntPredicate, OptimizationLevel, AddressSpace, IntPredicate, OptimizationLevel,
}; };
use itertools::{chain, izip, Either, Itertools}; use itertools::{chain, izip, Either, Itertools};
@ -34,7 +37,10 @@ use super::{
need_sret, numpy, need_sret, numpy,
object::{ object::{
any::AnyObject, any::AnyObject,
ndarray::{indexing::util::gen_ndarray_subscript_ndindices, NDArrayObject}, ndarray::{
indexing::util::gen_ndarray_subscript_ndindices, NDArrayObject, NDArrayOut,
ScalarOrNDArray,
},
}, },
stmt::{ stmt::{
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise, gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
@ -1549,99 +1555,71 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
} else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) } else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|| ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) || ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
{ {
let llvm_usize = generator.get_size_type(ctx.ctx); let left =
ScalarOrNDArray::split_object(generator, ctx, AnyObject { ty: ty1, value: left_val });
let right =
ScalarOrNDArray::split_object(generator, ctx, AnyObject { ty: ty2, value: right_val });
let is_ndarray1 = ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); // Inhomogeneous binary operations are not supported.
let is_ndarray2 = ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); assert!(ctx.unifier.unioned(left.get_dtype(), right.get_dtype()));
if is_ndarray1 && is_ndarray2 { let common_dtype = left.get_dtype();
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1);
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty2);
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); let out = match op.variant {
BinopVariant::Normal => NDArrayOut::NewNDArray { dtype: common_dtype },
let left_val = BinopVariant::AugAssign => {
NDArrayValue::from_ptr_val(left_val.into_pointer_value(), llvm_usize, None); // If this is an augmented assignment.
let right_val = // `left` has to be an ndarray. If it were a scalar then NAC3 simply doesn't support it.
NDArrayValue::from_ptr_val(right_val.into_pointer_value(), llvm_usize, None); if let ScalarOrNDArray::NDArray(out_ndarray) = left {
NDArrayOut::WriteToNDArray { ndarray: out_ndarray }
let res = if op.base == Operator::MatMult {
// MatMult is the only binop which is not an elementwise op
numpy::ndarray_matmul_2d(
generator,
ctx,
ndarray_dtype1,
match op.variant {
BinopVariant::Normal => None,
BinopVariant::AugAssign => Some(left_val),
},
left_val,
right_val,
)?
} else { } else {
numpy::ndarray_elementwise_binop_impl( panic!("left must be an ndarray")
generator, }
ctx, }
ndarray_dtype1,
match op.variant {
BinopVariant::Normal => None,
BinopVariant::AugAssign => Some(left_val),
},
(left_val.as_base_value().into(), false),
(right_val.as_base_value().into(), false),
|generator, ctx, (lhs, rhs)| {
gen_binop_expr_with_values(
generator,
ctx,
(&Some(ndarray_dtype1), lhs),
op,
(&Some(ndarray_dtype2), rhs),
ctx.current_loc,
)?
.unwrap()
.to_basic_value_enum(
ctx,
generator,
ndarray_dtype1,
)
},
)?
}; };
Ok(Some(res.as_base_value().into())) if op.base == Operator::MatMult {
// Handle matrix multiplication.
todo!()
} else { } else {
let (ndarray_dtype, _) = // For other operations, they are all elementwise operations.
unpack_ndarray_var_tys(&mut ctx.unifier, if is_ndarray1 { ty1 } else { ty2 });
let ndarray_val = NDArrayValue::from_ptr_val( // There are only three cases:
if is_ndarray1 { left_val } else { right_val }.into_pointer_value(), // - LHS is a scalar, RHS is an ndarray.
llvm_usize, // - LHS is an ndarray, RHS is a scalar.
None, // - LHS is an ndarray, RHS is an ndarray.
); //
let res = numpy::ndarray_elementwise_binop_impl( // For all cases, the scalar operand is promoted to an ndarray,
// the two are then broadcasted, and starmapped through.
let left = left.to_ndarray(generator, ctx);
let right = right.to_ndarray(generator, ctx);
let result = NDArrayObject::broadcast_starmap(
generator, generator,
ctx, ctx,
ndarray_dtype, &[left, right],
match op.variant { out,
BinopVariant::Normal => None, |generator, ctx, scalars| {
BinopVariant::AugAssign => Some(ndarray_val), let left_value = scalars[0];
}, let right_value = scalars[1];
(left_val, !is_ndarray1),
(right_val, !is_ndarray2), let result = gen_binop_expr_with_values(
|generator, ctx, (lhs, rhs)| {
gen_binop_expr_with_values(
generator, generator,
ctx, ctx,
(&Some(ndarray_dtype), lhs), (&Some(left.dtype), left_value),
op, op,
(&Some(ndarray_dtype), rhs), (&Some(right.dtype), right_value),
ctx.current_loc, ctx.current_loc,
)? )?
.unwrap() .unwrap()
.to_basic_value_enum(ctx, generator, ndarray_dtype) .to_basic_value_enum(ctx, generator, common_dtype)?;
},
)?;
Ok(Some(res.as_base_value().into())) Ok(result)
},
)
.unwrap();
Ok(Some(ValueEnum::Dynamic(result.instance.value.as_basic_value_enum())))
} }
} else { } else {
let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap()); let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap());