From c1dee4205961fb3491f8743b0ef5740989dbfe7f Mon Sep 17 00:00:00 2001 From: lyken Date: Tue, 27 Aug 2024 11:47:40 +0800 Subject: [PATCH] core/typecheck: add missing typecheck in matmul --- nac3core/src/typecheck/magic_methods.rs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index 325f837a..9344dd21 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -520,6 +520,23 @@ pub fn typeof_binop( } Operator::MatMult => { + // NOTE: NumPy matmul's LHS and RHS must both be ndarrays. Scalars are not allowed. + match (&*unifier.get_ty(lhs), &*unifier.get_ty(rhs)) { + ( + TypeEnum::TObj { obj_id: lhs_obj_id, .. }, + TypeEnum::TObj { obj_id: rhs_obj_id, .. }, + ) if *lhs_obj_id == primitives.ndarray.obj_id(unifier).unwrap() + && *rhs_obj_id == primitives.ndarray.obj_id(unifier).unwrap() => + { + // LHS and RHS have valid types + } + _ => { + let lhs_str = unifier.stringify(lhs); + let rhs_str = unifier.stringify(rhs); + return Err(format!("ndarray.__matmul__ only accepts ndarray operands, but left operand has type {lhs_str}, and right operand has type {rhs_str}")); + } + } + let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs); let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) { TypeEnum::TLiteral { values, .. } => {