From 59f19e29df1f9765ff6b83cffd3a9ca5d450b145 Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 19 Dec 2024 10:25:35 +0800 Subject: [PATCH] [core] codegen: Reimplement ndarray binop Based on 9e40c834: core/ndstrides: implement binop --- nac3core/src/codegen/expr.rs | 163 +++++++++++++++++------------------ 1 file changed, 80 insertions(+), 83 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 8d7f8e35..8225df7c 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -34,14 +34,19 @@ use super::{ }, types::{ndarray::NDArrayType, ListType}, values::{ - ndarray::RustNDIndex, ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, RangeValue, + ndarray::{NDArrayOut, RustNDIndex, ScalarOrNDArray}, + ArrayLikeIndexer, ArrayLikeValue, ListValue, ProxyValue, RangeValue, UntypedArrayLikeAccessor, }, CodeGenContext, CodeGenTask, CodeGenerator, }; use crate::{ symbol_resolver::{SymbolValue, ValueEnum}, - toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef}, + toplevel::{ + helper::{arraylike_flatten_element_type, PrimDef}, + numpy::unpack_ndarray_var_tys, + DefinitionId, TopLevelDef, + }, typecheck::{ magic_methods::{Binop, BinopVariant, HasOpInfo}, typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap}, @@ -1526,98 +1531,90 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( } else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) || ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { - let is_ndarray1 = ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); - let is_ndarray2 = ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()); + let left = ScalarOrNDArray::from_value(generator, ctx, (ty1, left_val)); + let right = ScalarOrNDArray::from_value(generator, ctx, (ty2, right_val)); - if is_ndarray1 && is_ndarray2 { + if op.base == Operator::MatMult { let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1); - let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty2); - assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + let left = left.to_ndarray(generator, ctx); + let right = right.to_ndarray(generator, ctx); - let left_val = NDArrayType::from_unifier_type(generator, ctx, ty1) - .map_value(left_val.into_pointer_value(), None); - let right_val = NDArrayType::from_unifier_type(generator, ctx, ty2) - .map_value(right_val.into_pointer_value(), None); - - let res = if op.base == Operator::MatMult { - // MatMult is the only binop which is not an elementwise op - numpy::ndarray_matmul_2d( - generator, - ctx, - ndarray_dtype1, - match op.variant { - BinopVariant::Normal => None, - BinopVariant::AugAssign => Some(left_val), - }, - left_val, - right_val, - )? - } else { - numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - ndarray_dtype1, - match op.variant { - BinopVariant::Normal => None, - BinopVariant::AugAssign => Some(left_val), - }, - (ty1, left_val.as_base_value().into(), false), - (ty2, right_val.as_base_value().into(), false), - |generator, ctx, (lhs, rhs)| { - gen_binop_expr_with_values( - generator, - ctx, - (&Some(ndarray_dtype1), lhs), - op, - (&Some(ndarray_dtype2), rhs), - ctx.current_loc, - )? - .unwrap() - .to_basic_value_enum( - ctx, - generator, - ndarray_dtype1, - ) - }, - )? - }; - - Ok(Some(res.as_base_value().into())) - } else { - let (ndarray_dtype, _) = - unpack_ndarray_var_tys(&mut ctx.unifier, if is_ndarray1 { ty1 } else { ty2 }); - let ndarray_val = - NDArrayType::from_unifier_type(generator, ctx, if is_ndarray1 { ty1 } else { ty2 }) - .map_value( - if is_ndarray1 { left_val } else { right_val }.into_pointer_value(), - None, - ); - let res = numpy::ndarray_elementwise_binop_impl( + // MatMult is the only binop which is not an elementwise op + let result = numpy::ndarray_matmul_2d( generator, ctx, - ndarray_dtype, + ndarray_dtype1, match op.variant { BinopVariant::Normal => None, - BinopVariant::AugAssign => Some(ndarray_val), - }, - (ty1, left_val, !is_ndarray1), - (ty2, right_val, !is_ndarray2), - |generator, ctx, (lhs, rhs)| { - gen_binop_expr_with_values( - generator, - ctx, - (&Some(ndarray_dtype), lhs), - op, - (&Some(ndarray_dtype), rhs), - ctx.current_loc, - )? - .unwrap() - .to_basic_value_enum(ctx, generator, ndarray_dtype) + BinopVariant::AugAssign => Some(left), }, + left, + right, )?; - Ok(Some(res.as_base_value().into())) + Ok(Some(result.as_base_value().into())) + } else { + // For other operations, they are all elementwise operations. + + // There are only three cases: + // - LHS is a scalar, RHS is an ndarray. + // - LHS is an ndarray, RHS is a scalar. + // - LHS is an ndarray, RHS is an ndarray. + // + // For all cases, the scalar operand is promoted to an ndarray, + // the two are then broadcasted, and starmapped through. + + let ty1_dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty1); + let ty2_dtype = arraylike_flatten_element_type(&mut ctx.unifier, ty2); + + // Inhomogeneous binary operations are not supported. + assert!(ctx.unifier.unioned(ty1_dtype, ty2_dtype)); + + let common_dtype = ty1_dtype; + let llvm_common_dtype = left.get_dtype(); + + let out = match op.variant { + BinopVariant::Normal => NDArrayOut::NewNDArray { dtype: llvm_common_dtype }, + BinopVariant::AugAssign => { + // If this is an augmented assignment. + // `left` has to be an ndarray. If it were a scalar then NAC3 simply doesn't support it. + if let ScalarOrNDArray::NDArray(out_ndarray) = left { + NDArrayOut::WriteToNDArray { ndarray: out_ndarray } + } else { + panic!("left must be an ndarray") + } + } + }; + + let left = left.to_ndarray(generator, ctx); + let right = right.to_ndarray(generator, ctx); + + let result = NDArrayType::new_broadcast( + generator, + ctx.ctx, + llvm_common_dtype, + &[left.get_type(), right.get_type()], + ) + .broadcast_starmap(generator, ctx, &[left, right], out, |generator, ctx, scalars| { + let left_value = scalars[0]; + let right_value = scalars[1]; + + let result = gen_binop_expr_with_values( + generator, + ctx, + (&Some(ty1_dtype), left_value), + op, + (&Some(ty2_dtype), right_value), + ctx.current_loc, + )? + .unwrap() + .to_basic_value_enum(ctx, generator, common_dtype)?; + + Ok(result) + }) + .unwrap(); + Ok(Some(result.as_base_value().into())) } } else { let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap());