From 847615fc2f7d7736f3e66a00aea9e2584eda3975 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 19 Apr 2024 19:00:07 +0800 Subject: [PATCH] core: Implement numpy.matmul for 2D-2D ndarrays --- nac3core/src/codegen/expr.rs | 65 +++-- nac3core/src/codegen/numpy.rs | 332 +++++++++++++++++++++++- nac3core/src/codegen/stmt.rs | 19 +- nac3core/src/typecheck/magic_methods.rs | 45 +++- nac3standalone/demo/src/ndarray.py | 15 ++ 5 files changed, 434 insertions(+), 42 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index e29312115..8812c8026 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -384,7 +384,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { rhs: BasicValueEnum<'ctx>, ) -> BasicValueEnum<'ctx> { let (BasicValueEnum::FloatValue(lhs), BasicValueEnum::FloatValue(rhs)) = (lhs, rhs) else { - unreachable!() + unreachable!("Expected (FloatValue, FloatValue), got ({}, {})", lhs.get_type(), rhs.get_type()) }; match op { Operator::Add => self.builder.build_float_add(lhs, rhs, "fadd").map(Into::into).unwrap(), @@ -589,8 +589,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { // even if this assumption is violated, it does not matter as exception unwinding is // slow anyway... let cond = call_expect(self, cond, i1_true, Some("expect")); - let current_fun = self.builder.get_insert_block().unwrap().get_parent().unwrap(); - let then_block = self.ctx.append_basic_block(current_fun, "succ"); + let current_bb = self.builder.get_insert_block().unwrap(); + let current_fun = current_bb.get_parent().unwrap(); + let then_block = self.ctx.insert_basic_block_after(current_bb, "succ"); let exn_block = self.ctx.append_basic_block(current_fun, "fail"); self.builder.build_conditional_branch(cond, then_block, exn_block).unwrap(); self.builder.position_at_end(exn_block); @@ -1148,27 +1149,45 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( let left_val = NDArrayValue::from_ptr_val( left_val.into_pointer_value(), llvm_usize, - None + None, ); - let res = numpy::ndarray_elementwise_binop_impl( - generator, - ctx, - ndarray_dtype1, - if is_aug_assign { Some(left_val) } else { None }, - (left_val.as_ptr_value().into(), false), - (right_val, false), - |generator, ctx, (lhs, rhs)| { - gen_binop_expr_with_values( - generator, - ctx, - (&Some(ndarray_dtype1), lhs), - op, - (&Some(ndarray_dtype2), rhs), - ctx.current_loc, - is_aug_assign, - )?.unwrap().to_basic_value_enum(ctx, generator, ndarray_dtype1) - }, - )?; + let right_val = NDArrayValue::from_ptr_val( + right_val.into_pointer_value(), + llvm_usize, + None, + ); + + let res = if *op == Operator::MatMult { + // MatMult is the only binop which is not an elementwise op + numpy::ndarray_matmul_2d( + generator, + ctx, + ndarray_dtype1, + if is_aug_assign { Some(left_val) } else { None }, + left_val, + right_val, + )? + } else { + numpy::ndarray_elementwise_binop_impl( + generator, + ctx, + ndarray_dtype1, + if is_aug_assign { Some(left_val) } else { None }, + (left_val.as_ptr_value().into(), false), + (right_val.as_ptr_value().into(), false), + |generator, ctx, (lhs, rhs)| { + gen_binop_expr_with_values( + generator, + ctx, + (&Some(ndarray_dtype1), lhs), + op, + (&Some(ndarray_dtype2), rhs), + ctx.current_loc, + is_aug_assign, + )?.unwrap().to_basic_value_enum(ctx, generator, ndarray_dtype1) + }, + )? + }; Ok(Some(res.as_ptr_value().into())) } else { diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 418f201a6..16036f3e8 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -1,9 +1,5 @@ -use inkwell::{ - IntPredicate, - types::BasicType, - values::{BasicValueEnum, IntValue, PointerValue} -}; -use nac3parser::ast::StrRef; +use inkwell::{IntPredicate, OptimizationLevel, types::BasicType, values::{BasicValueEnum, IntValue, PointerValue}}; +use nac3parser::ast::{Operator, StrRef}; use crate::{ codegen::{ classes::{ @@ -14,17 +10,20 @@ use crate::{ TypedArrayLikeAccessor, TypedArrayLikeAdapter, UntypedArrayLikeAccessor, + UntypedArrayLikeMutator, }, CodeGenContext, CodeGenerator, + expr::gen_binop_expr_with_values, irrt::{ call_ndarray_calc_broadcast, call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices, call_ndarray_calc_size, }, - llvm_intrinsics::call_memcpy_generic, - stmt::gen_for_callback_incrementing, + llvm_intrinsics, + llvm_intrinsics::{call_memcpy_generic}, + stmt::{gen_for_callback_incrementing, gen_if_else_expr_callback}, }, symbol_resolver::ValueEnum, toplevel::{ @@ -85,6 +84,8 @@ fn create_ndarray_dyn_shape<'ctx, 'a, G, V, LenFn, DataFn>( [None, None, None], ctx.current_loc, ); + + // TODO: Disallow dim_sz > u32_MAX Ok(()) }, @@ -171,6 +172,8 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>( [None, None, None], ctx.current_loc, ); + + // TODO: Disallow dim_sz > u32_MAX } let ndarray = generator.gen_var_alloc( @@ -824,6 +827,319 @@ pub fn ndarray_elementwise_binop_impl<'ctx, G, ValueFn>( Ok(ndarray) } +/// LLVM-typed implementation for computing matrix multiplication between two 2D `ndarray`s. +/// +/// * `elem_ty` - The element type of the `NDArray`. +/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be +/// written to a new `ndarray`. +pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + elem_ty: Type, + res: Option>, + lhs: NDArrayValue<'ctx>, + rhs: NDArrayValue<'ctx>, +) -> Result, String> { + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + + if cfg!(debug_assertions) { + let lhs_ndims = lhs.load_ndims(ctx); + let rhs_ndims = rhs.load_ndims(ctx); + + // lhs.ndims == 2 + ctx.make_assert( + generator, + ctx.builder.build_int_compare( + IntPredicate::EQ, + lhs_ndims, + llvm_usize.const_int(2, false), + "", + ).unwrap(), + "0:ValueError", + "", + [None, None, None], + ctx.current_loc, + ); + + // rhs.ndims == 2 + ctx.make_assert( + generator, + ctx.builder.build_int_compare( + IntPredicate::EQ, + rhs_ndims, + llvm_usize.const_int(2, false), + "", + ).unwrap(), + "0:ValueError", + "", + [None, None, None], + ctx.current_loc, + ); + + if let Some(res) = res { + let res_ndims = res.load_ndims(ctx); + let res_dim0 = unsafe { + res.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + }; + let res_dim1 = unsafe { + res.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) + }; + let lhs_dim0 = unsafe { + lhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + }; + let rhs_dim1 = unsafe { + rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) + }; + + // res.ndims == 2 + ctx.make_assert( + generator, + ctx.builder.build_int_compare( + IntPredicate::EQ, + res_ndims, + llvm_usize.const_int(2, false), + "", + ).unwrap(), + "0:ValueError", + "", + [None, None, None], + ctx.current_loc, + ); + + // res.dims[0] == lhs.dims[0] + ctx.make_assert( + generator, + ctx.builder.build_int_compare( + IntPredicate::EQ, + lhs_dim0, + res_dim0, + "", + ).unwrap(), + "0:ValueError", + "", + [None, None, None], + ctx.current_loc, + ); + + // res.dims[1] == rhs.dims[0] + ctx.make_assert( + generator, + ctx.builder.build_int_compare( + IntPredicate::EQ, + rhs_dim1, + res_dim1, + "", + ).unwrap(), + "0:ValueError", + "", + [None, None, None], + ctx.current_loc, + ); + } + } + + if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { + let lhs_dim1 = unsafe { + lhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None) + }; + let rhs_dim0 = unsafe { + rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) + }; + + // lhs.dims[1] == rhs.dims[0] + ctx.make_assert( + generator, + ctx.builder.build_int_compare( + IntPredicate::EQ, + lhs_dim1, + rhs_dim0, + "", + ).unwrap(), + "0:ValueError", + "", + [None, None, None], + ctx.current_loc, + ); + } + + let lhs = if res.is_some_and(|res| res.as_ptr_value() == lhs.as_ptr_value()) { + ndarray_copy_impl(generator, ctx, elem_ty, lhs)? + } else { + lhs + }; + + let ndarray = res.unwrap_or_else(|| { + create_ndarray_dyn_shape( + generator, + ctx, + elem_ty, + &(lhs, rhs), + |_, _, _| { + Ok(llvm_usize.const_int(2, false)) + }, + |generator, ctx, (lhs, rhs), idx| { + gen_if_else_expr_callback( + generator, + ctx, + |_, ctx| { + Ok(ctx.builder.build_int_compare( + IntPredicate::EQ, + idx, + llvm_usize.const_zero(), + "", + ).unwrap()) + }, + |generator, ctx| { + Ok(Some(unsafe { + lhs.dim_sizes().get_typed_unchecked( + ctx, + generator, + &llvm_usize.const_zero(), + None, + ) + })) + }, + |generator, ctx| { + Ok(Some(unsafe { + rhs.dim_sizes().get_typed_unchecked( + ctx, + generator, + &llvm_usize.const_int(1, false), + None, + ) + })) + }, + ).map(|v| v.map(BasicValueEnum::into_int_value).unwrap()) + }, + ).unwrap() + }); + + let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); + + ndarray_fill_indexed( + generator, + ctx, + ndarray, + |generator, ctx, idx| { + llvm_intrinsics::call_expect( + ctx, + idx.size(ctx, generator).get_type().const_int(2, false), + idx.size(ctx, generator), + None, + ); + + let common_dim = { + let lhs_idx1 = unsafe { + lhs.dim_sizes().get_typed_unchecked( + ctx, + generator, + &llvm_usize.const_int(1, false), + None, + ) + }; + let rhs_idx0 = unsafe { + rhs.dim_sizes().get_typed_unchecked( + ctx, + generator, + &llvm_usize.const_zero(), + None, + ) + }; + + let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None); + + ctx.builder.build_int_truncate(idx, llvm_i32, "").unwrap() + }; + + let idx0 = unsafe { + let idx0 = idx.get_typed_unchecked( + ctx, + generator, + &llvm_usize.const_zero(), + None, + ); + + ctx.builder.build_int_truncate(idx0, llvm_i32, "").unwrap() + }; + let idx1 = unsafe { + let idx1 = idx.get_typed_unchecked( + ctx, + generator, + &llvm_usize.const_int(1, false), + None, + ); + + ctx.builder.build_int_truncate(idx1, llvm_i32, "").unwrap() + }; + + let result_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?; + let result_identity = ndarray_zero_value(generator, ctx, elem_ty); + ctx.builder.build_store(result_addr, result_identity).unwrap(); + + gen_for_callback_incrementing( + generator, + ctx, + llvm_i32.const_zero(), + (common_dim, false), + |generator, ctx, i| { + let i = ctx.builder.build_int_truncate(i, llvm_i32, "").unwrap(); + + let ab_idx = generator.gen_array_var_alloc( + ctx, + llvm_i32.into(), + llvm_usize.const_int(2, false), + None, + )?; + + let a = unsafe { + ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), idx0.into()); + ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_int(1, false), i.into()); + + lhs.data().get_unchecked(ctx, generator, &ab_idx, None) + }; + let b = unsafe { + ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), i.into()); + ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_int(1, false), idx1.into()); + + rhs.data().get_unchecked(ctx, generator, &ab_idx, None) + }; + + let a_mul_b = gen_binop_expr_with_values( + generator, + ctx, + (&Some(elem_ty), a), + &Operator::Mult, + (&Some(elem_ty), b), + ctx.current_loc, + false, + )?.unwrap().to_basic_value_enum(ctx, generator, elem_ty)?; + + let result = ctx.builder.build_load(result_addr, "").unwrap(); + let result = gen_binop_expr_with_values( + generator, + ctx, + (&Some(elem_ty), result), + &Operator::Add, + (&Some(elem_ty), a_mul_b), + ctx.current_loc, + false, + )?.unwrap().to_basic_value_enum(ctx, generator, elem_ty)?; + ctx.builder.build_store(result_addr, result).unwrap(); + + Ok(()) + }, + llvm_usize.const_int(1, false), + )?; + + let result = ctx.builder.build_load(result_addr, "").unwrap(); + Ok(result) + } + )?; + + Ok(ndarray) +} + /// Generates LLVM IR for `ndarray.empty`. pub fn gen_ndarray_empty<'ctx>( context: &mut CodeGenContext<'ctx, '_>, diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index fe9d61776..ae0e00984 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -495,14 +495,14 @@ pub fn gen_for_callback<'ctx, 'a, G, I, InitFn, CondFn, BodyFn, UpdateFn>( BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>, UpdateFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>, { - let current = ctx.builder.get_insert_block().and_then(BasicBlock::get_parent).unwrap(); - let init_bb = ctx.ctx.append_basic_block(current, "for.init"); + let current_bb = ctx.builder.get_insert_block().unwrap(); + let init_bb = ctx.ctx.insert_basic_block_after(current_bb, "for.init"); // The BB containing the loop condition check - let cond_bb = ctx.ctx.append_basic_block(current, "for.cond"); - let body_bb = ctx.ctx.append_basic_block(current, "for.body"); + let cond_bb = ctx.ctx.insert_basic_block_after(init_bb, "for.cond"); + let body_bb = ctx.ctx.insert_basic_block_after(cond_bb, "for.body"); // The BB containing the increment expression - let update_bb = ctx.ctx.append_basic_block(current, "for.update"); - let cont_bb = ctx.ctx.append_basic_block(current, "for.end"); + let update_bb = ctx.ctx.insert_basic_block_after(body_bb, "for.update"); + let cont_bb = ctx.ctx.insert_basic_block_after(update_bb, "for.end"); // store loop bb information and restore it later let loop_bb = ctx.loop_target.replace((update_bb, cont_bb)); @@ -719,12 +719,10 @@ pub fn gen_if_else_expr_callback<'ctx, 'a, G, CondFn, ThenFn, ElseFn, R>( R: BasicValue<'ctx>, { let current_bb = ctx.builder.get_insert_block().unwrap(); - let current_fn = current_bb.get_parent().unwrap(); - - let end_bb = ctx.ctx.append_basic_block(current_fn, "if.end"); let then_bb = ctx.ctx.insert_basic_block_after(current_bb, "if.then"); - let else_bb = ctx.ctx.insert_basic_block_after(current_bb, "if.else"); + let else_bb = ctx.ctx.insert_basic_block_after(then_bb, "if.else"); + let end_bb = ctx.ctx.insert_basic_block_after(else_bb, "if.end"); let cond = cond_fn(generator, ctx)?; assert_eq!(cond.get_type().get_bit_width(), ctx.ctx.bool_type().get_bit_width()); @@ -742,6 +740,7 @@ pub fn gen_if_else_expr_callback<'ctx, 'a, G, CondFn, ThenFn, ElseFn, R>( ctx.builder.build_unconditional_branch(end_bb).unwrap(); } + ctx.builder.position_at_end(end_bb); let phi = match (then_val, else_val) { (Some(tv), Some(ev)) => { let tv_ty = tv.as_basic_value_enum().get_type(); diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index dc7afe5c2..fd38c02cb 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -291,6 +291,17 @@ pub fn impl_mod( impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Mod]); } +/// [Operator::MatMult] +pub fn impl_matmul( + unifier: &mut Unifier, + store: &PrimitiveStore, + ty: Type, + other_ty: &[Type], + ret_ty: Option, +) { + impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::MatMult]) +} + /// `UAdd`, `USub` pub fn impl_sign(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, ret_ty: Option) { impl_unaryop(unifier, ty, ret_ty, &[Unaryop::UAdd, Unaryop::USub]); @@ -431,7 +442,38 @@ pub fn typeof_binop( } } - Operator::MatMult => typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?, + Operator::MatMult => { + let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs); + let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) { + TypeEnum::TLiteral { values, .. } => { + assert_eq!(values.len(), 1); + u64::try_from(values[0].clone()).unwrap() + } + _ => unreachable!(), + }; + let (_, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs); + let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) { + TypeEnum::TLiteral { values, .. } => { + assert_eq!(values.len(), 1); + u64::try_from(values[0].clone()).unwrap() + } + _ => unreachable!(), + }; + + match (lhs_ndims, rhs_ndims) { + (2, 2) => typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?, + (lhs, rhs) if lhs == 0 || rhs == 0 => { + return Err(format!( + "Input operand {} does not have enough dimensions (has {lhs}, requires {rhs})", + (rhs == 0) as u8 + )) + } + (lhs, rhs) => { + return Err(format!("ndarray.__matmul__ on {lhs}D and {rhs}D operands not supported")) + } + } + } + Operator::Div => { if is_left_ndarray || is_right_ndarray { typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)? @@ -610,6 +652,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None); impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); + impl_matmul(unifier, store, ndarray_t, &[ndarray_t], Some(ndarray_t)); impl_sign(unifier, store, ndarray_t, Some(ndarray_t)); impl_invert(unifier, store, ndarray_t, Some(ndarray_t)); impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index 371bdb480..eafe39d81 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -429,6 +429,19 @@ def test_ndarray_ipow_broadcast_scalar(): output_ndarray_float_2(x) +def test_ndarray_matmul(): + x = np_identity(2) + y = x @ np_ones([2, 2]) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_imatmul(): + x = np_identity(2) + x @= np_ones([2, 2]) + + output_ndarray_float_2(x) + def test_ndarray_pos(): x_int32 = np_full([2, 2], -2) y_int32 = +x_int32 @@ -696,6 +709,8 @@ def run() -> int32: test_ndarray_ipow() test_ndarray_ipow_broadcast() test_ndarray_ipow_broadcast_scalar() + test_ndarray_matmul() + test_ndarray_imatmul() test_ndarray_pos() test_ndarray_neg() test_ndarray_inv()