forked from M-Labs/nac3
1
0
Fork 0

core/typecheck: add missing typecheck in matmul

This commit is contained in:
lyken 2024-08-27 11:47:40 +08:00 committed by sb10q
parent 308edb8237
commit 22c4d25802
1 changed files with 17 additions and 0 deletions

View File

@ -520,6 +520,23 @@ pub fn typeof_binop(
} }
Operator::MatMult => { Operator::MatMult => {
// NOTE: NumPy matmul's LHS and RHS must both be ndarrays. Scalars are not allowed.
match (&*unifier.get_ty(lhs), &*unifier.get_ty(rhs)) {
(
TypeEnum::TObj { obj_id: lhs_obj_id, .. },
TypeEnum::TObj { obj_id: rhs_obj_id, .. },
) if *lhs_obj_id == primitives.ndarray.obj_id(unifier).unwrap()
&& *rhs_obj_id == primitives.ndarray.obj_id(unifier).unwrap() =>
{
// LHS and RHS have valid types
}
_ => {
let lhs_str = unifier.stringify(lhs);
let rhs_str = unifier.stringify(rhs);
return Err(format!("ndarray.__matmul__ only accepts ndarray operands, but left operand has type {lhs_str}, and right operand has type {rhs_str}"));
}
}
let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs); let (_, lhs_ndims) = unpack_ndarray_var_tys(unifier, lhs);
let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) { let lhs_ndims = match &*unifier.get_ty_immutable(lhs_ndims) {
TypeEnum::TLiteral { values, .. } => { TypeEnum::TLiteral { values, .. } => {