forked from M-Labs/nac3
core: improve binop and cmpop error messages
This commit is contained in:
parent
0a732691c9
commit
f52086b706
@ -1202,11 +1202,11 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
{
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
if is_aug_assign {
|
||||
if op.variant == BinopVariant::AugAssign {
|
||||
todo!("Augmented assignment operators not implemented for lists")
|
||||
}
|
||||
|
||||
match op {
|
||||
match op.base {
|
||||
Operator::Add => {
|
||||
debug_assert_eq!(ty1.obj_id(&ctx.unifier), Some(PrimDef::List.id()));
|
||||
debug_assert_eq!(ty2.obj_id(&ctx.unifier), Some(PrimDef::List.id()));
|
||||
|
@ -486,18 +486,20 @@ pub fn typeof_binop(
|
||||
lhs: Type,
|
||||
rhs: Type,
|
||||
) -> Result<Option<Type>, String> {
|
||||
let op = Binop { base: op, variant: BinopVariant::Normal };
|
||||
|
||||
let is_left_list = lhs.obj_id(unifier).is_some_and(|id| id == PrimDef::List.id());
|
||||
let is_right_list = rhs.obj_id(unifier).is_some_and(|id| id == PrimDef::List.id());
|
||||
let is_left_ndarray = lhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||
let is_right_ndarray = rhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
||||
|
||||
Ok(Some(match op {
|
||||
Ok(Some(match op.base {
|
||||
Operator::Add | Operator::Sub | Operator::Mult | Operator::Mod | Operator::FloorDiv => {
|
||||
if is_left_list || is_right_list {
|
||||
if ![Operator::Add, Operator::Mult].contains(&op) {
|
||||
if ![Operator::Add, Operator::Mult].contains(&op.base) {
|
||||
return Err(format!(
|
||||
"Binary operator {} not supported for list",
|
||||
binop_name(op)
|
||||
op.op_info().symbol
|
||||
));
|
||||
}
|
||||
|
||||
|
@ -1,11 +1,14 @@
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Display;
|
||||
|
||||
use crate::typecheck::typedef::TypeEnum;
|
||||
use crate::typecheck::{magic_methods::HasOpInfo, typedef::TypeEnum};
|
||||
|
||||
use super::typedef::{RecordKey, Type, Unifier};
|
||||
use super::{
|
||||
magic_methods::Binop,
|
||||
typedef::{RecordKey, Type, Unifier},
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use nac3parser::ast::{Location, StrRef};
|
||||
use nac3parser::ast::{Cmpop, Location, StrRef};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum TypeErrorKind {
|
||||
@ -26,6 +29,18 @@ pub enum TypeErrorKind {
|
||||
expected: Type,
|
||||
got: Type,
|
||||
},
|
||||
UnsupportedBinaryOpTypes {
|
||||
operator: Binop,
|
||||
lhs_type: Type,
|
||||
rhs_type: Type,
|
||||
expected_rhs_type: Type,
|
||||
},
|
||||
UnsupportedComparsionOpTypes {
|
||||
operator: Cmpop,
|
||||
lhs_type: Type,
|
||||
rhs_type: Type,
|
||||
expected_rhs_type: Type,
|
||||
},
|
||||
FieldUnificationError {
|
||||
field: RecordKey,
|
||||
types: (Type, Type),
|
||||
@ -101,6 +116,26 @@ impl<'a> Display for DisplayTypeError<'a> {
|
||||
let args = missing_arg_names.iter().join(", ");
|
||||
write!(f, "Missing arguments: {args}")
|
||||
}
|
||||
UnsupportedBinaryOpTypes { operator, lhs_type, rhs_type, expected_rhs_type } => {
|
||||
let op_symbol = operator.op_info().symbol;
|
||||
|
||||
let lhs_type_str = self.unifier.stringify_with_notes(*lhs_type, &mut notes);
|
||||
let rhs_type_str = self.unifier.stringify_with_notes(*rhs_type, &mut notes);
|
||||
let expected_rhs_type_str =
|
||||
self.unifier.stringify_with_notes(*expected_rhs_type, &mut notes);
|
||||
|
||||
write!(f, "Unsupported operand type(s) for {op_symbol}: '{lhs_type_str}' and '{rhs_type_str}' (right operand should have type {expected_rhs_type_str})")
|
||||
}
|
||||
UnsupportedComparsionOpTypes { operator, lhs_type, rhs_type, expected_rhs_type } => {
|
||||
let op_symbol = operator.op_info().symbol;
|
||||
|
||||
let lhs_type_str = self.unifier.stringify_with_notes(*lhs_type, &mut notes);
|
||||
let rhs_type_str = self.unifier.stringify_with_notes(*rhs_type, &mut notes);
|
||||
let expected_rhs_type_str =
|
||||
self.unifier.stringify_with_notes(*expected_rhs_type, &mut notes);
|
||||
|
||||
write!(f, "'{op_symbol}' not supported between instances of '{lhs_type_str}' and '{rhs_type_str}' (right operand should have type {expected_rhs_type_str})")
|
||||
}
|
||||
UnknownArgName(name) => {
|
||||
write!(f, "Unknown argument name: {name}")
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ use std::iter::once;
|
||||
use std::ops::Not;
|
||||
use std::{cell::RefCell, sync::Arc};
|
||||
|
||||
use super::typedef::OperatorInfo;
|
||||
use super::{
|
||||
magic_methods::*,
|
||||
type_error::TypeError,
|
||||
@ -641,6 +642,7 @@ impl<'a> Inferencer<'a> {
|
||||
obj: Type,
|
||||
params: Vec<Type>,
|
||||
ret: Option<Type>,
|
||||
operator_info: Option<OperatorInfo>,
|
||||
) -> InferenceResult {
|
||||
if let TypeEnum::TObj { params: class_params, fields, .. } = &*self.unifier.get_ty(obj) {
|
||||
if class_params.is_empty() {
|
||||
@ -654,6 +656,7 @@ impl<'a> Inferencer<'a> {
|
||||
ret: sign.ret,
|
||||
fun: RefCell::new(None),
|
||||
loc: Some(location),
|
||||
operator_info,
|
||||
};
|
||||
if let Some(ret) = ret {
|
||||
self.unifier
|
||||
@ -688,6 +691,7 @@ impl<'a> Inferencer<'a> {
|
||||
ret,
|
||||
fun: RefCell::new(None),
|
||||
loc: Some(location),
|
||||
operator_info,
|
||||
});
|
||||
self.calls.insert(location.into(), call);
|
||||
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call]));
|
||||
@ -1523,6 +1527,7 @@ impl<'a> Inferencer<'a> {
|
||||
fun: RefCell::new(None),
|
||||
ret: sign.ret,
|
||||
loc: Some(location),
|
||||
operator_info: None,
|
||||
};
|
||||
self.unifier.unify_call(&call, func.custom.unwrap(), sign).map_err(|e| {
|
||||
HashSet::from([e.at(Some(location)).to_display(self.unifier).to_string()])
|
||||
@ -1545,6 +1550,7 @@ impl<'a> Inferencer<'a> {
|
||||
fun: RefCell::new(None),
|
||||
ret,
|
||||
loc: Some(location),
|
||||
operator_info: None,
|
||||
});
|
||||
self.calls.insert(location.into(), call);
|
||||
let call = self.unifier.add_ty(TypeEnum::TCall(vec![call]));
|
||||
@ -1765,7 +1771,14 @@ impl<'a> Inferencer<'a> {
|
||||
}
|
||||
};
|
||||
|
||||
self.build_method_call(location, method.into(), left_ty, vec![right_ty], ret)
|
||||
self.build_method_call(
|
||||
location,
|
||||
method.into(),
|
||||
left_ty,
|
||||
vec![right_ty],
|
||||
ret,
|
||||
Some(OperatorInfo::IsBinaryOp { self_type: left.custom.unwrap(), operator: op }),
|
||||
)
|
||||
}
|
||||
|
||||
fn infer_unary_ops(
|
||||
@ -1779,7 +1792,14 @@ impl<'a> Inferencer<'a> {
|
||||
let ret = typeof_unaryop(self.unifier, self.primitives, op, operand.custom.unwrap())
|
||||
.map_err(|e| HashSet::from([format!("{e} (at {location})")]))?;
|
||||
|
||||
self.build_method_call(location, method, operand.custom.unwrap(), vec![], ret)
|
||||
self.build_method_call(
|
||||
location,
|
||||
method,
|
||||
operand.custom.unwrap(),
|
||||
vec![],
|
||||
ret,
|
||||
Some(OperatorInfo::IsUnaryOp { self_type: operand.custom.unwrap(), operator: op }),
|
||||
)
|
||||
}
|
||||
|
||||
fn infer_compare(
|
||||
@ -1825,6 +1845,10 @@ impl<'a> Inferencer<'a> {
|
||||
a.custom.unwrap(),
|
||||
vec![b.custom.unwrap()],
|
||||
ret,
|
||||
Some(OperatorInfo::IsComparisonOp {
|
||||
self_type: left.custom.unwrap(),
|
||||
operator: *c,
|
||||
}),
|
||||
)?);
|
||||
}
|
||||
|
||||
|
@ -8,12 +8,15 @@ use std::rc::Rc;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::{borrow::Cow, collections::HashSet};
|
||||
|
||||
use nac3parser::ast::{Location, StrRef};
|
||||
use nac3parser::ast::{Cmpop, Location, StrRef, Unaryop};
|
||||
|
||||
use super::magic_methods::Binop;
|
||||
use super::type_error::{TypeError, TypeErrorKind};
|
||||
use super::unification_table::{UnificationKey, UnificationTable};
|
||||
use crate::symbol_resolver::SymbolValue;
|
||||
use crate::toplevel::{helper::PrimDef, DefinitionId, TopLevelContext, TopLevelDef};
|
||||
use crate::toplevel::helper::PrimDef;
|
||||
use crate::toplevel::{DefinitionId, TopLevelContext, TopLevelDef};
|
||||
use crate::typecheck::magic_methods::OpInfo;
|
||||
use crate::typecheck::type_inferencer::PrimitiveStore;
|
||||
|
||||
#[cfg(test)]
|
||||
@ -73,6 +76,28 @@ pub fn iter_type_vars(var_map: &VarMap) -> impl Iterator<Item = TypeVar> + '_ {
|
||||
var_map.iter().map(|(&id, &ty)| TypeVar { id, ty })
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum OperatorInfo {
|
||||
/// The call was written as an unary operation, e.g., `~a` or `not a`.
|
||||
IsUnaryOp {
|
||||
/// The [`Type`] of the `self` object
|
||||
self_type: Type,
|
||||
operator: Unaryop,
|
||||
},
|
||||
/// The call was written as a binary operation, e.g., `a + b` or `a += b`.
|
||||
IsBinaryOp {
|
||||
/// The [`Type`] of the `self` object
|
||||
self_type: Type,
|
||||
operator: Binop,
|
||||
},
|
||||
/// The call was written as a binary comparison operation, e.g., `a < b`.
|
||||
IsComparisonOp {
|
||||
/// The [`Type`] of the `self` object
|
||||
self_type: Type,
|
||||
operator: Cmpop,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Call {
|
||||
pub posargs: Vec<Type>,
|
||||
@ -80,6 +105,9 @@ pub struct Call {
|
||||
pub ret: Type,
|
||||
pub fun: RefCell<Option<Type>>,
|
||||
pub loc: Option<Location>,
|
||||
|
||||
/// Details about the associated Python user operator expression of this call, if any.
|
||||
pub operator_info: Option<OperatorInfo>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@ -627,10 +655,67 @@ impl Unifier {
|
||||
let TypeEnum::TFunc(signature) = &*self.get_ty(b) else { unreachable!() };
|
||||
|
||||
// Get details about the input arguments
|
||||
let Call { posargs, kwargs, ret, fun, loc } = call;
|
||||
let Call { posargs, kwargs, ret, fun, loc, operator_info } = call;
|
||||
let num_args = posargs.len() + kwargs.len();
|
||||
|
||||
// Now we check the arguments against the parameters
|
||||
// Now we check the arguments against the parameters,
|
||||
// and depending on what `call_info` is, we might change how the behavior `unify_call()`
|
||||
// in hopes to improve user error messages when type checking fails.
|
||||
match operator_info {
|
||||
Some(OperatorInfo::IsBinaryOp { self_type, operator }) => {
|
||||
// The call is written in the form of (say) `a + b`.
|
||||
// Technically, it is `a.__add__(b)`, and they have the following constraints:
|
||||
assert_eq!(posargs.len(), 1);
|
||||
assert_eq!(kwargs.len(), 0);
|
||||
assert_eq!(num_params, 1);
|
||||
|
||||
let other_type = posargs[0]; // the second operand
|
||||
let expected_other_type = signature.args[0].ty;
|
||||
|
||||
let ok = self.unify_impl(expected_other_type, other_type, false).is_ok();
|
||||
if !ok {
|
||||
self.restore_snapshot();
|
||||
return Err(TypeError::new(
|
||||
TypeErrorKind::UnsupportedBinaryOpTypes {
|
||||
operator: *operator,
|
||||
lhs_type: *self_type,
|
||||
rhs_type: other_type,
|
||||
expected_rhs_type: expected_other_type,
|
||||
},
|
||||
*loc,
|
||||
));
|
||||
}
|
||||
}
|
||||
Some(OperatorInfo::IsComparisonOp { self_type, operator })
|
||||
if OpInfo::supports_cmpop(*operator) // Otherwise that comparison operator is not supported.
|
||||
=>
|
||||
{
|
||||
// The call is written in the form of (say) `a <= b`.
|
||||
// Technically, it is `a.__le__(b)`, and they have the following constraints:
|
||||
assert_eq!(posargs.len(), 1);
|
||||
assert_eq!(kwargs.len(), 0);
|
||||
assert_eq!(num_params, 1);
|
||||
|
||||
let other_type = posargs[0]; // the second operand
|
||||
let expected_other_type = signature.args[0].ty;
|
||||
|
||||
let ok = self.unify_impl(expected_other_type, other_type, false).is_ok();
|
||||
if !ok {
|
||||
self.restore_snapshot();
|
||||
return Err(TypeError::new(
|
||||
TypeErrorKind::UnsupportedComparsionOpTypes {
|
||||
operator: *operator,
|
||||
lhs_type: *self_type,
|
||||
rhs_type: other_type,
|
||||
expected_rhs_type: expected_other_type,
|
||||
},
|
||||
*loc,
|
||||
));
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// Handle [`CallInfo::IsNormalFunctionCall`] and other uninteresting variants
|
||||
// of [`CallInfo`] (e.g, `CallInfo::IsUnaryOp` and unsupported comparison operators)
|
||||
|
||||
// Helper lambdas
|
||||
let mut type_check_arg = |param_name, expected_arg_ty, arg_ty| {
|
||||
@ -691,7 +776,10 @@ impl Unifier {
|
||||
|
||||
let Some(param_info) = param_info_by_name.get_mut(¶m_name) else {
|
||||
self.restore_snapshot();
|
||||
return Err(TypeError::new(TypeErrorKind::UnknownArgName(param_name), *loc));
|
||||
return Err(TypeError::new(
|
||||
TypeErrorKind::UnknownArgName(param_name),
|
||||
*loc,
|
||||
));
|
||||
};
|
||||
|
||||
if param_info.has_been_supplied {
|
||||
@ -716,12 +804,17 @@ impl Unifier {
|
||||
// unsupplied required parameters, and throw an error if they exist.
|
||||
let missing_arg_names = param_info_by_name
|
||||
.values()
|
||||
.filter(|param_info| param_info.param.is_required() && !param_info.has_been_supplied)
|
||||
.filter(|param_info| {
|
||||
param_info.param.is_required() && !param_info.has_been_supplied
|
||||
})
|
||||
.map(|param_info| param_info.param.name)
|
||||
.collect_vec();
|
||||
if !missing_arg_names.is_empty() {
|
||||
self.restore_snapshot();
|
||||
return Err(TypeError::new(TypeErrorKind::MissingArgs { missing_arg_names }, *loc));
|
||||
return Err(TypeError::new(
|
||||
TypeErrorKind::MissingArgs { missing_arg_names },
|
||||
*loc,
|
||||
));
|
||||
}
|
||||
|
||||
// Finally, check the Call's return type
|
||||
@ -732,6 +825,8 @@ impl Unifier {
|
||||
}
|
||||
err
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
*fun.borrow_mut() = Some(b);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user