From 73c2203b898440705b26f16a8d18b5b143e7d1db Mon Sep 17 00:00:00 2001 From: lyken Date: Sun, 25 Aug 2024 00:04:10 +0800 Subject: [PATCH 1/2] core/ndstrides: implement general matmul --- nac3core/irrt/irrt.cpp | 3 +- nac3core/irrt/irrt/ndarray/matmul.hpp | 100 ++++++ nac3core/src/codegen/expr.rs | 6 +- nac3core/src/codegen/irrt/mod.rs | 27 ++ nac3core/src/codegen/numpy.rs | 306 +----------------- nac3core/src/codegen/object/ndarray/matmul.rs | 216 +++++++++++++ nac3core/src/codegen/object/ndarray/mod.rs | 1 + ...el__test__test_analyze__generic_class.snap | 2 +- ...t__test_analyze__inheritance_override.snap | 2 +- ...est__test_analyze__list_tuple_generic.snap | 4 +- ...__toplevel__test__test_analyze__self1.snap | 2 +- ...t__test_analyze__simple_class_compose.snap | 4 +- nac3core/src/typecheck/magic_methods.rs | 84 +++-- 13 files changed, 401 insertions(+), 356 deletions(-) create mode 100644 nac3core/irrt/irrt/ndarray/matmul.hpp create mode 100644 nac3core/src/codegen/object/ndarray/matmul.rs diff --git a/nac3core/irrt/irrt.cpp b/nac3core/irrt/irrt.cpp index 95856830..70ef2392 100644 --- a/nac3core/irrt/irrt.cpp +++ b/nac3core/irrt/irrt.cpp @@ -12,4 +12,5 @@ #include "irrt/ndarray/array.hpp" #include "irrt/ndarray/reshape.hpp" #include "irrt/ndarray/broadcast.hpp" -#include "irrt/ndarray/transpose.hpp" \ No newline at end of file +#include "irrt/ndarray/transpose.hpp" +#include "irrt/ndarray/matmul.hpp" \ No newline at end of file diff --git a/nac3core/irrt/irrt/ndarray/matmul.hpp b/nac3core/irrt/irrt/ndarray/matmul.hpp new file mode 100644 index 00000000..8c0a6a60 --- /dev/null +++ b/nac3core/irrt/irrt/ndarray/matmul.hpp @@ -0,0 +1,100 @@ +#pragma once + +#include "irrt/debug.hpp" +#include "irrt/exception.hpp" +#include "irrt/int_types.hpp" +#include "irrt/ndarray/basic.hpp" +#include "irrt/ndarray/broadcast.hpp" +#include "irrt/ndarray/iter.hpp" + +// NOTE: Everything would be much easier and elegant if einsum is implemented. + +namespace { +namespace ndarray { +namespace matmul { + +/** + * @brief Perform the broadcast in `np.einsum("...ij,...jk->...ik", a, b)`. + * + * Example: + * Suppose `a_shape == [1, 97, 4, 2]` + * and `b_shape == [99, 98, 1, 2, 5]`, + * + * ...then `new_a_shape == [99, 98, 97, 4, 2]`, + * `new_b_shape == [99, 98, 97, 2, 5]`, + * and `dst_shape == [99, 98, 97, 4, 5]`. + * ^^^^^^^^^^ ^^^^ + * (broadcasted) (4x2 @ 2x5 => 4x5) + * + * @param a_ndims Length of `a_shape`. + * @param a_shape Shape of `a`. + * @param b_ndims Length of `b_shape`. + * @param b_shape Shape of `b`. + * @param final_ndims Should be equal to `max(a_ndims, b_ndims)`. This is the length of `new_a_shape`, + * `new_b_shape`, and `dst_shape` - the number of dimensions after broadcasting. + */ +template +void calculate_shapes(SizeT a_ndims, + SizeT* a_shape, + SizeT b_ndims, + SizeT* b_shape, + SizeT final_ndims, + SizeT* new_a_shape, + SizeT* new_b_shape, + SizeT* dst_shape) { + debug_assert(SizeT, a_ndims >= 2); + debug_assert(SizeT, b_ndims >= 2); + debug_assert_eq(SizeT, max(a_ndims, b_ndims), final_ndims); + + // Check that a and b are compatible for matmul + if (a_shape[a_ndims - 1] != b_shape[b_ndims - 2]) { + // This is a custom error message. Different from NumPy. + raise_exception(SizeT, EXN_VALUE_ERROR, "Cannot multiply LHS (shape ?x{0}) with RHS (shape {1}x?})", + a_shape[a_ndims - 1], b_shape[b_ndims - 2], NO_PARAM); + } + + const SizeT num_entries = 2; + ShapeEntry entries[num_entries] = {{.ndims = a_ndims - 2, .shape = a_shape}, + {.ndims = b_ndims - 2, .shape = b_shape}}; + + // TODO: Optimize this + ndarray::broadcast::broadcast_shapes(num_entries, entries, final_ndims - 2, new_a_shape); + ndarray::broadcast::broadcast_shapes(num_entries, entries, final_ndims - 2, new_b_shape); + ndarray::broadcast::broadcast_shapes(num_entries, entries, final_ndims - 2, dst_shape); + + new_a_shape[final_ndims - 2] = a_shape[a_ndims - 2]; + new_a_shape[final_ndims - 1] = a_shape[a_ndims - 1]; + new_b_shape[final_ndims - 2] = b_shape[b_ndims - 2]; + new_b_shape[final_ndims - 1] = b_shape[b_ndims - 1]; + dst_shape[final_ndims - 2] = a_shape[a_ndims - 2]; + dst_shape[final_ndims - 1] = b_shape[b_ndims - 1]; +} +} // namespace matmul +} // namespace ndarray +} // namespace + +extern "C" { +using namespace ndarray::matmul; + +void __nac3_ndarray_matmul_calculate_shapes(int32_t a_ndims, + int32_t* a_shape, + int32_t b_ndims, + int32_t* b_shape, + int32_t final_ndims, + int32_t* new_a_shape, + int32_t* new_b_shape, + int32_t* dst_shape) { + calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims, new_a_shape, new_b_shape, dst_shape); +} + +void __nac3_ndarray_matmul_calculate_shapes64(int64_t a_ndims, + int64_t* a_shape, + int64_t b_ndims, + int64_t* b_shape, + int64_t final_ndims, + int64_t* new_a_shape, + int64_t* new_b_shape, + int64_t* dst_shape) { + calculate_shapes(a_ndims, a_shape, b_ndims, b_shape, final_ndims, new_a_shape, new_b_shape, dst_shape); +} +} \ No newline at end of file diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 9e3faf25..b111ea9a 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1580,7 +1580,11 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( if op.base == Operator::MatMult { // Handle matrix multiplication. - todo!() + let left = left.to_ndarray(generator, ctx); + let right = right.to_ndarray(generator, ctx); + let result = NDArrayObject::matmul(generator, ctx, left, right, out) + .split_unsized(generator, ctx); + Ok(Some(ValueEnum::Dynamic(result.to_basic_value_enum()))) } else { // For other operations, they are all elementwise operations. diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 77e39369..22c2b53d 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -1220,3 +1220,30 @@ pub fn call_nac3_ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( .arg(axes) .returning_void(); } + +#[allow(clippy::too_many_arguments)] +pub fn call_nac3_ndarray_matmul_calculate_shapes<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + a_ndims: Instance<'ctx, Int>, + a_shape: Instance<'ctx, Ptr>>, + b_ndims: Instance<'ctx, Int>, + b_shape: Instance<'ctx, Ptr>>, + final_ndims: Instance<'ctx, Int>, + new_a_shape: Instance<'ctx, Ptr>>, + new_b_shape: Instance<'ctx, Ptr>>, + dst_shape: Instance<'ctx, Ptr>>, +) { + let name = + get_sizet_dependent_function_name(generator, ctx, "__nac3_ndarray_matmul_calculate_shapes"); + FnCall::builder(generator, ctx, &name) + .arg(a_ndims) + .arg(a_shape) + .arg(b_ndims) + .arg(b_shape) + .arg(final_ndims) + .arg(new_a_shape) + .arg(new_b_shape) + .arg(dst_shape) + .returning_void(); +} diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 31002ee5..8d00c7cc 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -1,10 +1,10 @@ use inkwell::{ types::{AnyTypeEnum, BasicType, BasicTypeEnum, PointerType}, values::{BasicValue, BasicValueEnum, IntValue, PointerValue}, - AddressSpace, IntPredicate, OptimizationLevel, + AddressSpace, IntPredicate, }; -use nac3parser::ast::{Operator, StrRef}; +use nac3parser::ast::StrRef; use super::{ classes::{ @@ -12,7 +12,6 @@ use super::{ ProxyType, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter, TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }, - expr::gen_binop_expr_with_values, irrt::{ calculate_len_for_slice_range, call_ndarray_calc_broadcast, call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices, call_ndarray_calc_size, @@ -34,10 +33,7 @@ use crate::{ numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, DefinitionId, }, - typecheck::{ - magic_methods::Binop, - typedef::{FunSignature, Type}, - }, + typecheck::typedef::{FunSignature, Type}, }; /// Creates an uninitialized `NDArray` instance. @@ -1437,302 +1433,6 @@ where Ok(ndarray) } -/// LLVM-typed implementation for computing matrix multiplication between two 2D `ndarray`s. -/// -/// * `elem_ty` - The element type of the `NDArray`. -/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be -/// written to a new `ndarray`. -pub fn ndarray_matmul_2d<'ctx, G: CodeGenerator>( - generator: &mut G, - ctx: &mut CodeGenContext<'ctx, '_>, - elem_ty: Type, - res: Option>, - lhs: NDArrayValue<'ctx>, - rhs: NDArrayValue<'ctx>, -) -> Result, String> { - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - - if cfg!(debug_assertions) { - let lhs_ndims = lhs.load_ndims(ctx); - let rhs_ndims = rhs.load_ndims(ctx); - - // lhs.ndims == 2 - ctx.make_assert( - generator, - ctx.builder - .build_int_compare(IntPredicate::EQ, lhs_ndims, llvm_usize.const_int(2, false), "") - .unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - - // rhs.ndims == 2 - ctx.make_assert( - generator, - ctx.builder - .build_int_compare(IntPredicate::EQ, rhs_ndims, llvm_usize.const_int(2, false), "") - .unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - - if let Some(res) = res { - let res_ndims = res.load_ndims(ctx); - let res_dim0 = unsafe { - res.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - let res_dim1 = unsafe { - res.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - }; - let lhs_dim0 = unsafe { - lhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - let rhs_dim1 = unsafe { - rhs.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - }; - - // res.ndims == 2 - ctx.make_assert( - generator, - ctx.builder - .build_int_compare( - IntPredicate::EQ, - res_ndims, - llvm_usize.const_int(2, false), - "", - ) - .unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - - // res.dims[0] == lhs.dims[0] - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim0, res_dim0, "").unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - - // res.dims[1] == rhs.dims[0] - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::EQ, rhs_dim1, res_dim1, "").unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - } - } - - if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { - let lhs_dim1 = unsafe { - lhs.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - }; - let rhs_dim0 = unsafe { - rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - - // lhs.dims[1] == rhs.dims[0] - ctx.make_assert( - generator, - ctx.builder.build_int_compare(IntPredicate::EQ, lhs_dim1, rhs_dim0, "").unwrap(), - "0:ValueError", - "", - [None, None, None], - ctx.current_loc, - ); - } - - let lhs = if res.is_some_and(|res| res.as_base_value() == lhs.as_base_value()) { - ndarray_copy_impl(generator, ctx, elem_ty, lhs)? - } else { - lhs - }; - - let ndarray = res.unwrap_or_else(|| { - create_ndarray_dyn_shape( - generator, - ctx, - elem_ty, - &(lhs, rhs), - |_, _, _| Ok(llvm_usize.const_int(2, false)), - |generator, ctx, (lhs, rhs), idx| { - gen_if_else_expr_callback( - generator, - ctx, - |_, ctx| { - Ok(ctx - .builder - .build_int_compare(IntPredicate::EQ, idx, llvm_usize.const_zero(), "") - .unwrap()) - }, - |generator, ctx| { - Ok(Some(unsafe { - lhs.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_zero(), - None, - ) - })) - }, - |generator, ctx| { - Ok(Some(unsafe { - rhs.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - })) - }, - ) - .map(|v| v.map(BasicValueEnum::into_int_value).unwrap()) - }, - ) - .unwrap() - }); - - let llvm_ndarray_ty = ctx.get_llvm_type(generator, elem_ty); - - ndarray_fill_indexed(generator, ctx, ndarray, |generator, ctx, idx| { - llvm_intrinsics::call_expect( - ctx, - idx.size(ctx, generator).get_type().const_int(2, false), - idx.size(ctx, generator), - None, - ); - - let common_dim = { - let lhs_idx1 = unsafe { - lhs.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - None, - ) - }; - let rhs_idx0 = unsafe { - rhs.dim_sizes().get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None) - }; - - let idx = llvm_intrinsics::call_expect(ctx, rhs_idx0, lhs_idx1, None); - - ctx.builder.build_int_truncate(idx, llvm_i32, "").unwrap() - }; - - let idx0 = unsafe { - let idx0 = idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_zero(), None); - - ctx.builder.build_int_truncate(idx0, llvm_i32, "").unwrap() - }; - let idx1 = unsafe { - let idx1 = - idx.get_typed_unchecked(ctx, generator, &llvm_usize.const_int(1, false), None); - - ctx.builder.build_int_truncate(idx1, llvm_i32, "").unwrap() - }; - - let result_addr = generator.gen_var_alloc(ctx, llvm_ndarray_ty, None)?; - let result_identity = ndarray_zero_value(generator, ctx, elem_ty); - ctx.builder.build_store(result_addr, result_identity).unwrap(); - - gen_for_callback_incrementing( - generator, - ctx, - None, - llvm_i32.const_zero(), - (common_dim, false), - |generator, ctx, _, i| { - let i = ctx.builder.build_int_truncate(i, llvm_i32, "").unwrap(); - - let ab_idx = generator.gen_array_var_alloc( - ctx, - llvm_i32.into(), - llvm_usize.const_int(2, false), - None, - )?; - - let a = unsafe { - ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), idx0.into()); - ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_int(1, false), i.into()); - - lhs.data().get_unchecked(ctx, generator, &ab_idx, None) - }; - let b = unsafe { - ab_idx.set_unchecked(ctx, generator, &llvm_usize.const_zero(), i.into()); - ab_idx.set_unchecked( - ctx, - generator, - &llvm_usize.const_int(1, false), - idx1.into(), - ); - - rhs.data().get_unchecked(ctx, generator, &ab_idx, None) - }; - - let a_mul_b = gen_binop_expr_with_values( - generator, - ctx, - (&Some(elem_ty), a), - Binop::normal(Operator::Mult), - (&Some(elem_ty), b), - ctx.current_loc, - )? - .unwrap() - .to_basic_value_enum(ctx, generator, elem_ty)?; - - let result = ctx.builder.build_load(result_addr, "").unwrap(); - let result = gen_binop_expr_with_values( - generator, - ctx, - (&Some(elem_ty), result), - Binop::normal(Operator::Add), - (&Some(elem_ty), a_mul_b), - ctx.current_loc, - )? - .unwrap() - .to_basic_value_enum(ctx, generator, elem_ty)?; - ctx.builder.build_store(result_addr, result).unwrap(); - - Ok(()) - }, - llvm_usize.const_int(1, false), - )?; - - let result = ctx.builder.build_load(result_addr, "").unwrap(); - Ok(result) - })?; - - Ok(ndarray) -} - /// Generates LLVM IR for `ndarray.empty`. pub fn gen_ndarray_empty<'ctx>( context: &mut CodeGenContext<'ctx, '_>, diff --git a/nac3core/src/codegen/object/ndarray/matmul.rs b/nac3core/src/codegen/object/ndarray/matmul.rs new file mode 100644 index 00000000..bcdf60dd --- /dev/null +++ b/nac3core/src/codegen/object/ndarray/matmul.rs @@ -0,0 +1,216 @@ +use std::cmp::max; + +use nac3parser::ast::Operator; + +use super::{util::gen_for_model, NDArrayObject, NDArrayOut}; +use crate::{ + codegen::{ + expr::gen_binop_expr_with_values, irrt::call_nac3_ndarray_matmul_calculate_shapes, + model::*, object::ndarray::indexing::RustNDIndex, CodeGenContext, CodeGenerator, + }, + typecheck::{magic_methods::Binop, typedef::Type}, +}; + +/// Perform `np.einsum("...ij,...jk->...ik", in_a, in_b)`. +/// +/// `dst_dtype` defines the dtype of the returned ndarray. +fn matmul_at_least_2d<'ctx, G: CodeGenerator>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + dst_dtype: Type, + in_a: NDArrayObject<'ctx>, + in_b: NDArrayObject<'ctx>, +) -> NDArrayObject<'ctx> { + assert!(in_a.ndims >= 2); + assert!(in_b.ndims >= 2); + + // Deduce ndims of the result of matmul. + let ndims_int = max(in_a.ndims, in_b.ndims); + let ndims = Int(SizeT).const_int(generator, ctx.ctx, ndims_int, false); + + // Broadcasts `in_a.shape[:-2]` and `in_b.shape[:-2]` together and allocate the + // destination ndarray to store the result of matmul. + let (lhs, rhs, dst) = { + let in_lhs_ndims = in_a.ndims_llvm(generator, ctx.ctx); + let in_lhs_shape = in_a.instance.get(generator, ctx, |f| f.shape); + let in_rhs_ndims = in_b.ndims_llvm(generator, ctx.ctx); + let in_rhs_shape = in_b.instance.get(generator, ctx, |f| f.shape); + let lhs_shape = Int(SizeT).array_alloca(generator, ctx, ndims.value); + let rhs_shape = Int(SizeT).array_alloca(generator, ctx, ndims.value); + let dst_shape = Int(SizeT).array_alloca(generator, ctx, ndims.value); + + // Matmul dimension compatibility is checked here. + call_nac3_ndarray_matmul_calculate_shapes( + generator, + ctx, + in_lhs_ndims, + in_lhs_shape, + in_rhs_ndims, + in_rhs_shape, + ndims, + lhs_shape, + rhs_shape, + dst_shape, + ); + + let lhs = in_a.broadcast_to(generator, ctx, ndims_int, lhs_shape); + let rhs = in_b.broadcast_to(generator, ctx, ndims_int, rhs_shape); + + let dst = NDArrayObject::alloca(generator, ctx, dst_dtype, ndims_int); + dst.copy_shape_from_array(generator, ctx, dst_shape); + dst.create_data(generator, ctx); + + (lhs, rhs, dst) + }; + + let len = lhs.instance.get(generator, ctx, |f| f.shape).get_index_const( + generator, + ctx, + i64::try_from(ndims_int - 1).unwrap(), + ); + + let at_row = i64::try_from(ndims_int - 2).unwrap(); + let at_col = i64::try_from(ndims_int - 1).unwrap(); + + let dst_dtype_llvm = ctx.get_llvm_type(generator, dst_dtype); + let dst_zero = dst_dtype_llvm.const_zero(); + + dst.foreach(generator, ctx, |generator, ctx, _, hdl| { + let pdst_ij = hdl.get_pointer(generator, ctx); + + ctx.builder.build_store(pdst_ij, dst_zero).unwrap(); + + let indices = hdl.get_indices(); + let i = indices.get_index_const(generator, ctx, at_row); + let j = indices.get_index_const(generator, ctx, at_col); + + let num_0 = Int(SizeT).const_int(generator, ctx.ctx, 0, false); + let num_1 = Int(SizeT).const_int(generator, ctx.ctx, 1, false); + + gen_for_model(generator, ctx, num_0, len, num_1, |generator, ctx, _, k| { + // `indices` is modified to index into `a` and `b`, and restored. + indices.set_index_const(ctx, at_row, i); + indices.set_index_const(ctx, at_col, k); + let a_ik = lhs.get_scalar_by_indices(generator, ctx, indices); + + indices.set_index_const(ctx, at_row, k); + indices.set_index_const(ctx, at_col, j); + let b_kj = rhs.get_scalar_by_indices(generator, ctx, indices); + + // Restore `indices`. + indices.set_index_const(ctx, at_row, i); + indices.set_index_const(ctx, at_col, j); + + // x = a_[...]ik * b_[...]kj + let x = gen_binop_expr_with_values( + generator, + ctx, + (&Some(lhs.dtype), a_ik.value), + Binop::normal(Operator::Mult), + (&Some(rhs.dtype), b_kj.value), + ctx.current_loc, + )? + .unwrap() + .to_basic_value_enum(ctx, generator, dst_dtype)?; + + // dst_[...]ij += x + let dst_ij = ctx.builder.build_load(pdst_ij, "").unwrap(); + let dst_ij = gen_binop_expr_with_values( + generator, + ctx, + (&Some(dst_dtype), dst_ij), + Binop::normal(Operator::Add), + (&Some(dst_dtype), x), + ctx.current_loc, + )? + .unwrap() + .to_basic_value_enum(ctx, generator, dst_dtype)?; + ctx.builder.build_store(pdst_ij, dst_ij).unwrap(); + + Ok(()) + }) + }) + .unwrap(); + + dst +} + +impl<'ctx> NDArrayObject<'ctx> { + /// Perform `np.matmul` according to the rules in + /// . + /// + /// This function always return an [`NDArrayObject`]. You may want to use [`NDArrayObject::split_unsized`] + /// to handle when the output could be a scalar. + /// + /// `dst_dtype` defines the dtype of the returned ndarray. + pub fn matmul( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + a: Self, + b: Self, + out: NDArrayOut<'ctx>, + ) -> Self { + // Sanity check, but type inference should prevent this. + assert!(a.ndims > 0 && b.ndims > 0, "np.matmul disallows scalar input"); + + /* + If both arguments are 2-D they are multiplied like conventional matrices. + If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the last two indices and broadcast accordingly. + If the first argument is 1-D, it is promoted to a matrix by prepending a 1 to its dimensions. After matrix multiplication the prepended 1 is removed. + If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its dimensions. After matrix multiplication the appended 1 is removed. + */ + + let new_a = if a.ndims == 1 { + // Prepend 1 to its dimensions + a.index(generator, ctx, &[RustNDIndex::NewAxis, RustNDIndex::Ellipsis]) + } else { + a + }; + + let new_b = if b.ndims == 1 { + // Append 1 to its dimensions + b.index(generator, ctx, &[RustNDIndex::Ellipsis, RustNDIndex::NewAxis]) + } else { + b + }; + + // NOTE: `result` will always be a newly allocated ndarray. + // Current implementation cannot do in-place matrix muliplication. + let mut result = matmul_at_least_2d(generator, ctx, out.get_dtype(), new_a, new_b); + + // Postprocessing on the result to remove prepended/appended axes. + let mut postindices = vec![]; + let zero = Int(Int32).const_0(generator, ctx.ctx); + + if a.ndims == 1 { + // Remove the prepended 1 + postindices.push(RustNDIndex::SingleElement(zero)); + } + + if b.ndims == 1 { + // Remove the appended 1 + postindices.push(RustNDIndex::Ellipsis); + postindices.push(RustNDIndex::SingleElement(zero)); + } + + if !postindices.is_empty() { + result = result.index(generator, ctx, &postindices); + } + + match out { + NDArrayOut::NewNDArray { .. } => result, + NDArrayOut::WriteToNDArray { ndarray: out_ndarray } => { + let result_shape = result.instance.get(generator, ctx, |f| f.shape); + out_ndarray.assert_can_be_written_by_out( + generator, + ctx, + result.ndims, + result_shape, + ); + + out_ndarray.copy_data_from(generator, ctx, result); + out_ndarray + } + } + } +} diff --git a/nac3core/src/codegen/object/ndarray/mod.rs b/nac3core/src/codegen/object/ndarray/mod.rs index 276ee3a7..48729645 100644 --- a/nac3core/src/codegen/object/ndarray/mod.rs +++ b/nac3core/src/codegen/object/ndarray/mod.rs @@ -27,6 +27,7 @@ pub mod broadcast; pub mod factory; pub mod indexing; pub mod map; +pub mod matmul; pub mod nditer; pub mod shape_util; pub mod view; diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index 48ea22a0..a997444a 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -5,7 +5,7 @@ expression: res_vec [ "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(250)]\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(257)]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index c381eb76..df4517b4 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -7,7 +7,7 @@ expression: res_vec "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B[typevar239]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar239\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B[typevar246]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar246\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index f57410fc..b1d9bf99 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -5,8 +5,8 @@ expression: res_vec [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(252)]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(257)]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(259)]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(264)]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index 45bf04fb..83746579 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar238, typevar239]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar238\", \"typevar239\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar245, typevar246]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar245\", \"typevar246\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index a25fac3c..aea0aee4 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -6,12 +6,12 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(258)]\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(265)]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(266)]\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(273)]\n}\n", ] diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index 60972f03..40bbdeab 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -7,12 +7,12 @@ use nac3parser::ast::{Cmpop, Operator, StrRef, Unaryop}; use super::{ type_inferencer::*, - typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, + typedef::{into_var_map, FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, }; use crate::{ symbol_resolver::SymbolValue, toplevel::{ - helper::PrimDef, + helper::{extract_ndims, PrimDef}, numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, }, }; @@ -175,19 +175,8 @@ pub fn impl_binop( ops: &[Operator], ) { with_fields(unifier, ty, |unifier, fields| { - let (other_ty, other_var_id) = if other_ty.len() == 1 { - (other_ty[0], None) - } else { - let tvar = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None); - (tvar.ty, Some(tvar.id)) - }; - - let function_vars = if let Some(var_id) = other_var_id { - vec![(var_id, other_ty)].into_iter().collect::() - } else { - VarMap::new() - }; - + let other_tvar = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None); + let function_vars = into_var_map([other_tvar]); let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).ty); for (base_op, variant) in iproduct!(ops, [BinopVariant::Normal, BinopVariant::AugAssign]) { @@ -198,7 +187,7 @@ pub fn impl_binop( ret: ret_ty, vars: function_vars.clone(), args: vec![FuncArg { - ty: other_ty, + ty: other_tvar.ty, default_value: None, name: "other".into(), is_vararg: false, @@ -541,36 +530,43 @@ pub fn typeof_binop( } } - let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs); - let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) { - TypeEnum::TLiteral { values, .. } => { - assert_eq!(values.len(), 1); - u64::try_from(values[0].clone()).unwrap() + let (lhs_dtype, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs); + let lhs_ndims = extract_ndims(unifier, lhs_ndims); + + let (rhs_dtype, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs); + let rhs_ndims = extract_ndims(unifier, rhs_ndims); + + if !(unifier.unioned(lhs_dtype, primitives.float) + && unifier.unioned(rhs_dtype, primitives.float)) + { + return Err(format!( + "ndarray.__matmul__ only supports float64 operations, but LHS has type {} and RHS has type {}", + unifier.stringify(lhs), + unifier.stringify(rhs) + )); + } + + // Deduce the ndims of the resulting ndarray. + // If this is 0 (an unsized ndarray), matmul returns a scalar just like NumPy. + let result_ndims = match (lhs_ndims, rhs_ndims) { + (0, _) | (_, 0) => { + return Err( + "ndarray.__matmul__ does not allow unsized ndarray input".to_string() + ) } - _ => unreachable!(), - }; - let (_, rhs_ndims) = unpack_ndarray_var_tys(unifier, rhs); - let rhs_ndims = match &*unifier.get_ty_immutable(rhs_ndims) { - TypeEnum::TLiteral { values, .. } => { - assert_eq!(values.len(), 1); - u64::try_from(values[0].clone()).unwrap() - } - _ => unreachable!(), + (1, 1) => 0, + (1, _) => rhs_ndims - 1, + (_, 1) => lhs_ndims - 1, + (m, n) => max(m, n), }; - match (lhs_ndims, rhs_ndims) { - (2, 2) => typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?, - (lhs, rhs) if lhs == 0 || rhs == 0 => { - return Err(format!( - "Input operand {} does not have enough dimensions (has {lhs}, requires {rhs})", - u8::from(rhs == 0) - )) - } - (lhs, rhs) => { - return Err(format!( - "ndarray.__matmul__ on {lhs}D and {rhs}D operands not supported" - )) - } + if result_ndims == 0 { + // If the result is unsized, NumPy returns a scalar. + primitives.float + } else { + let result_ndims_ty = + unifier.get_fresh_literal(vec![SymbolValue::U64(result_ndims)], None); + make_ndarray_ty(unifier, primitives, Some(primitives.float), Some(result_ndims_ty)) } } @@ -773,7 +769,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None); impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); - impl_matmul(unifier, store, ndarray_t, &[ndarray_t], Some(ndarray_t)); + impl_matmul(unifier, store, ndarray_t, &[ndarray_unsized_t], None); impl_sign(unifier, store, ndarray_t, Some(ndarray_t)); impl_invert(unifier, store, ndarray_t, Some(ndarray_t)); impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); -- 2.47.0 From 693b7f37742f5d89c6501721fe073095f743e0d1 Mon Sep 17 00:00:00 2001 From: lyken Date: Tue, 20 Aug 2024 21:18:59 +0800 Subject: [PATCH 2/2] core/ndstrides: implement np_dot() for scalars and 1D --- nac3core/src/codegen/numpy.rs | 124 +++++++++++++++++------------- nac3core/src/toplevel/builtins.rs | 4 +- 2 files changed, 72 insertions(+), 56 deletions(-) diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 8d00c7cc..e9f5dbd6 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -21,9 +21,12 @@ use super::{ model::*, object::{ any::AnyObject, - ndarray::{shape_util::parse_numpy_int_sequence, NDArrayObject}, + ndarray::{nditer::NDIterHandle, shape_util::parse_numpy_int_sequence, NDArrayObject}, + }, + stmt::{ + gen_for_callback, gen_for_callback_incrementing, gen_for_range_callback, + gen_if_else_expr_callback, }, - stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback}, CodeGenContext, CodeGenerator, }; use crate::{ @@ -1704,77 +1707,88 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>( ) -> Result, String> { const FN_NAME: &str = "ndarray_dot"; let (x1_ty, x1) = x1; - let (_, x2) = x2; - - let llvm_usize = generator.get_size_type(ctx.ctx); + let (x2_ty, x2) = x2; match (x1, x2) { - (BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) => { - let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None); - let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None); + (BasicValueEnum::PointerValue(_), BasicValueEnum::PointerValue(_)) => { + let a = AnyObject { ty: x1_ty, value: x1 }; + let b = AnyObject { ty: x2_ty, value: x2 }; - let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); - let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None)); + let a = NDArrayObject::from_object(generator, ctx, a); + let b = NDArrayObject::from_object(generator, ctx, b); + // TODO: General `np.dot()` https://numpy.org/doc/stable/reference/generated/numpy.dot.html. + assert_eq!(a.ndims, 1); + assert_eq!(b.ndims, 1); + let common_dtype = a.dtype; + + // Check shapes. + let a_size = a.size(generator, ctx); + let b_size = b.size(generator, ctx); + let same_shape = a_size.compare(ctx, IntPredicate::EQ, b_size); ctx.make_assert( generator, - ctx.builder.build_int_compare(IntPredicate::EQ, n1_sz, n2_sz, "").unwrap(), + same_shape.value, "0:ValueError", - "shapes ({0}), ({1}) not aligned", - [Some(n1_sz), Some(n2_sz), None], + "shapes ({0},) and ({1},) not aligned: {0} (dim 0) != {1} (dim 1)", + [Some(a_size.value), Some(b_size.value), None], ctx.current_loc, ); - let identity = - unsafe { n1.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) }; - let acc = ctx.builder.build_alloca(identity.get_type(), "").unwrap(); - ctx.builder.build_store(acc, identity.get_type().const_zero()).unwrap(); + let dtype_llvm = ctx.get_llvm_type(generator, common_dtype); - gen_for_callback_incrementing( + let result = ctx.builder.build_alloca(dtype_llvm, "np_dot_result").unwrap(); + ctx.builder.build_store(result, dtype_llvm.const_zero()).unwrap(); + + // Do dot product. + gen_for_callback( generator, ctx, - None, - llvm_usize.const_zero(), - (n1_sz, false), - |generator, ctx, _, idx| { - let elem1 = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) }; - let elem2 = unsafe { n2.data().get_unchecked(ctx, generator, &idx, None) }; + Some("np_dot"), + |generator, ctx| { + let a_iter = NDIterHandle::new(generator, ctx, a); + let b_iter = NDIterHandle::new(generator, ctx, b); + Ok((a_iter, b_iter)) + }, + |generator, ctx, (a_iter, _b_iter)| { + // Only a_iter drives the condition, b_iter should have the same status. + Ok(a_iter.has_element(generator, ctx).value) + }, + |generator, ctx, _hooks, (a_iter, b_iter)| { + let a_scalar = a_iter.get_scalar(generator, ctx).value; + let b_scalar = b_iter.get_scalar(generator, ctx).value; - let product = match elem1 { - BasicValueEnum::IntValue(e1) => ctx - .builder - .build_int_mul(e1, elem2.into_int_value(), "") - .unwrap() - .as_basic_value_enum(), - BasicValueEnum::FloatValue(e1) => ctx - .builder - .build_float_mul(e1, elem2.into_float_value(), "") - .unwrap() - .as_basic_value_enum(), - _ => codegen_unreachable!(ctx), + let old_result = ctx.builder.build_load(result, "").unwrap(); + let new_result: BasicValueEnum<'ctx> = match old_result { + BasicValueEnum::IntValue(old_result) => { + let a_scalar = a_scalar.into_int_value(); + let b_scalar = b_scalar.into_int_value(); + let x = ctx.builder.build_int_mul(a_scalar, b_scalar, "").unwrap(); + ctx.builder.build_int_add(old_result, x, "").unwrap().into() + } + BasicValueEnum::FloatValue(old_result) => { + let a_scalar = a_scalar.into_float_value(); + let b_scalar = b_scalar.into_float_value(); + let x = ctx.builder.build_float_mul(a_scalar, b_scalar, "").unwrap(); + ctx.builder.build_float_add(old_result, x, "").unwrap().into() + } + _ => { + panic!("Unrecognized dtype: {}", ctx.unifier.stringify(common_dtype)); + } }; - let acc_val = ctx.builder.build_load(acc, "").unwrap(); - let acc_val = match acc_val { - BasicValueEnum::IntValue(e1) => ctx - .builder - .build_int_add(e1, product.into_int_value(), "") - .unwrap() - .as_basic_value_enum(), - BasicValueEnum::FloatValue(e1) => ctx - .builder - .build_float_add(e1, product.into_float_value(), "") - .unwrap() - .as_basic_value_enum(), - _ => codegen_unreachable!(ctx), - }; - ctx.builder.build_store(acc, acc_val).unwrap(); + ctx.builder.build_store(result, new_result).unwrap(); Ok(()) }, - llvm_usize.const_int(1, false), - )?; - let acc_val = ctx.builder.build_load(acc, "").unwrap(); - Ok(acc_val) + |generator, ctx, (a_iter, b_iter)| { + a_iter.next(generator, ctx); + b_iter.next(generator, ctx); + Ok(()) + }, + ) + .unwrap(); + + Ok(ctx.builder.build_load(result, "").unwrap()) } (BasicValueEnum::IntValue(e1), BasicValueEnum::IntValue(e2)) => { Ok(ctx.builder.build_int_mul(e1, e2, "").unwrap().as_basic_value_enum()) diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index a47963a7..526599cc 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -2080,10 +2080,12 @@ impl<'a> BuiltinBuilder<'a> { Box::new(move |ctx, _, fun, args, generator| { let x1_ty = fun.0.args[0].ty; let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?; + let x2_ty = fun.0.args[1].ty; let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?; - Ok(Some(ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?)) + let result = ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?; + Ok(Some(result)) }), ), -- 2.47.0