forked from M-Labs/nac3
core: improve function call errors
This commit is contained in:
parent
ca8459dc7b
commit
0ec967a468
@ -4,15 +4,22 @@ use std::fmt::Display;
|
|||||||
use crate::typecheck::typedef::TypeEnum;
|
use crate::typecheck::typedef::TypeEnum;
|
||||||
|
|
||||||
use super::typedef::{RecordKey, Type, Unifier};
|
use super::typedef::{RecordKey, Type, Unifier};
|
||||||
|
use itertools::Itertools;
|
||||||
use nac3parser::ast::{Location, StrRef};
|
use nac3parser::ast::{Location, StrRef};
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum TypeErrorKind {
|
pub enum TypeErrorKind {
|
||||||
TooManyArguments {
|
GotMultipleValues {
|
||||||
expected: usize,
|
name: StrRef,
|
||||||
got: usize,
|
},
|
||||||
|
TooManyArguments {
|
||||||
|
expected_min_count: usize,
|
||||||
|
expected_max_count: usize,
|
||||||
|
got_count: usize,
|
||||||
|
},
|
||||||
|
MissingArgs {
|
||||||
|
missing_arg_names: Vec<StrRef>,
|
||||||
},
|
},
|
||||||
MissingArgs(String),
|
|
||||||
UnknownArgName(StrRef),
|
UnknownArgName(StrRef),
|
||||||
IncorrectArgType {
|
IncorrectArgType {
|
||||||
name: StrRef,
|
name: StrRef,
|
||||||
@ -78,10 +85,20 @@ impl<'a> Display for DisplayTypeError<'a> {
|
|||||||
use TypeErrorKind::*;
|
use TypeErrorKind::*;
|
||||||
let mut notes = Some(HashMap::new());
|
let mut notes = Some(HashMap::new());
|
||||||
match &self.err.kind {
|
match &self.err.kind {
|
||||||
TooManyArguments { expected, got } => {
|
GotMultipleValues { name } => {
|
||||||
write!(f, "Too many arguments. Expected {expected} but got {got}")
|
write!(f, "For multiple values for parameter {name}")
|
||||||
}
|
}
|
||||||
MissingArgs(args) => {
|
TooManyArguments { expected_min_count, expected_max_count, got_count } => {
|
||||||
|
debug_assert!(expected_min_count <= expected_max_count);
|
||||||
|
if expected_min_count == expected_max_count {
|
||||||
|
let expected_count = expected_min_count; // or expected_max_count
|
||||||
|
write!(f, "Too many arguments. Expected {expected_count} but got {got_count}")
|
||||||
|
} else {
|
||||||
|
write!(f, "Too many arguments. Expected {expected_min_count} to {expected_max_count} arguments but got {got_count}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MissingArgs { missing_arg_names } => {
|
||||||
|
let args = missing_arg_names.iter().join(", ");
|
||||||
write!(f, "Missing arguments: {args}")
|
write!(f, "Missing arguments: {args}")
|
||||||
}
|
}
|
||||||
UnknownArgName(name) => {
|
UnknownArgName(name) => {
|
||||||
@ -90,7 +107,7 @@ impl<'a> Display for DisplayTypeError<'a> {
|
|||||||
IncorrectArgType { name, expected, got } => {
|
IncorrectArgType { name, expected, got } => {
|
||||||
let expected = self.unifier.stringify_with_notes(*expected, &mut notes);
|
let expected = self.unifier.stringify_with_notes(*expected, &mut notes);
|
||||||
let got = self.unifier.stringify_with_notes(*got, &mut notes);
|
let got = self.unifier.stringify_with_notes(*got, &mut notes);
|
||||||
write!(f, "Incorrect argument type for {name}. Expected {expected}, but got {got}")
|
write!(f, "Incorrect argument type for parameter {name}. Expected {expected}, but got {got}")
|
||||||
}
|
}
|
||||||
FieldUnificationError { field, types, loc } => {
|
FieldUnificationError { field, types, loc } => {
|
||||||
let lhs = self.unifier.stringify_with_notes(types.0, &mut notes);
|
let lhs = self.unifier.stringify_with_notes(types.0, &mut notes);
|
||||||
|
@ -642,14 +642,7 @@ impl<'a> Inferencer<'a> {
|
|||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
let required: Vec<_> = sign
|
self.unifier.unify_call(&call, ty, sign).map_err(|e| {
|
||||||
.args
|
|
||||||
.iter()
|
|
||||||
.filter(|v| v.default_value.is_none())
|
|
||||||
.map(|v| v.name)
|
|
||||||
.rev()
|
|
||||||
.collect();
|
|
||||||
self.unifier.unify_call(&call, ty, sign, &required).map_err(|e| {
|
|
||||||
HashSet::from([e
|
HashSet::from([e
|
||||||
.at(Some(location))
|
.at(Some(location))
|
||||||
.to_display(self.unifier)
|
.to_display(self.unifier)
|
||||||
@ -1347,16 +1340,9 @@ impl<'a> Inferencer<'a> {
|
|||||||
ret: sign.ret,
|
ret: sign.ret,
|
||||||
loc: Some(location),
|
loc: Some(location),
|
||||||
};
|
};
|
||||||
let required: Vec<_> = sign
|
self.unifier.unify_call(&call, func.custom.unwrap(), sign).map_err(|e| {
|
||||||
.args
|
HashSet::from([e.at(Some(location)).to_display(self.unifier).to_string()])
|
||||||
.iter()
|
})?;
|
||||||
.filter(|v| v.default_value.is_none())
|
|
||||||
.map(|v| v.name)
|
|
||||||
.rev()
|
|
||||||
.collect();
|
|
||||||
self.unifier.unify_call(&call, func.custom.unwrap(), sign, &required).map_err(
|
|
||||||
|e| HashSet::from([e.at(Some(location)).to_display(self.unifier).to_string()]),
|
|
||||||
)?;
|
|
||||||
return Ok(Located {
|
return Ok(Located {
|
||||||
location,
|
location,
|
||||||
custom: Some(sign.ret),
|
custom: Some(sign.ret),
|
||||||
|
@ -89,6 +89,13 @@ pub struct FuncArg {
|
|||||||
pub default_value: Option<SymbolValue>,
|
pub default_value: Option<SymbolValue>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl FuncArg {
|
||||||
|
#[must_use]
|
||||||
|
pub fn is_required(&self) -> bool {
|
||||||
|
self.default_value.is_none()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct FunSignature {
|
pub struct FunSignature {
|
||||||
pub args: Vec<FuncArg>,
|
pub args: Vec<FuncArg>,
|
||||||
@ -562,61 +569,153 @@ impl Unifier {
|
|||||||
call: &Call,
|
call: &Call,
|
||||||
b: Type,
|
b: Type,
|
||||||
signature: &FunSignature,
|
signature: &FunSignature,
|
||||||
required: &[StrRef],
|
|
||||||
) -> Result<(), TypeError> {
|
) -> Result<(), TypeError> {
|
||||||
|
/*
|
||||||
|
NOTE: scenarios to consider:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def func1(x: int32, y: int32, z: int32 = 5): pass
|
||||||
|
|
||||||
|
# Normal scenarios
|
||||||
|
func1(23, 45) # OK, z has default
|
||||||
|
func1(23, 45, 67) # OK, z's default is overwritten
|
||||||
|
func1(x = 23, y = 45) # OK, user is using kwargs to set positional args
|
||||||
|
func1(y = 45, x = 23) # OK, kwargs order doesn't matter
|
||||||
|
|
||||||
|
# Error scenarios
|
||||||
|
func1() # ERROR: Missing arguments: x, y
|
||||||
|
func1(23) # ERROR: Missing arguments: y
|
||||||
|
func1(z = 23) # ERROR: Missing arguments: x, y
|
||||||
|
func1(x = 23) # ERROR: Missing arguments: y
|
||||||
|
func1(23, 45, x = 5) # ERROR: Got multiple values for x
|
||||||
|
func1(23, 45, x = 5, y = 6) # ERROR: Got multiple values for x (y too but Python does not report it)
|
||||||
|
func1(23, 45, 67, z = 89) # ERROR: Got multiple values for z
|
||||||
|
func1(23, 45, 67, 89) # ERROR: Function only takes from 2 to 3 positional arguments but 4 were given.
|
||||||
|
func1(23, 45, 67, w = 3) # ERROR: Got an unexpected keyword argument 'w'
|
||||||
|
|
||||||
|
# Error scenarios that do not need to be handled here.
|
||||||
|
func1(23, 45, z = 67, z = 89) # ERROR: Keyword argument repeated: z, the parser panics on this.
|
||||||
|
```
|
||||||
|
*/
|
||||||
|
|
||||||
|
struct ParamInfo<'a> {
|
||||||
|
/// Has this parameter been supplied with an argument already?
|
||||||
|
has_been_supplied: bool,
|
||||||
|
/// The corresponding [`FuncArg`] instance of this parameter (for fast table lookups)
|
||||||
|
param: &'a FuncArg,
|
||||||
|
}
|
||||||
|
|
||||||
let snapshot = self.unification_table.get_snapshot();
|
let snapshot = self.unification_table.get_snapshot();
|
||||||
if self.snapshot.is_none() {
|
if self.snapshot.is_none() {
|
||||||
self.snapshot = Some(snapshot);
|
self.snapshot = Some(snapshot);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get details about the function signature/parameters.
|
||||||
|
let num_params = signature.args.len();
|
||||||
|
|
||||||
|
// Force the type vars in `b` and `signature' to be up-to-date.
|
||||||
|
let b = self.instantiate_fun(b, signature);
|
||||||
|
let TypeEnum::TFunc(signature) = &*self.get_ty(b) else { unreachable!() };
|
||||||
|
|
||||||
|
// Get details about the input arguments
|
||||||
let Call { posargs, kwargs, ret, fun, loc } = call;
|
let Call { posargs, kwargs, ret, fun, loc } = call;
|
||||||
let instantiated = self.instantiate_fun(b, signature);
|
let num_args = posargs.len() + kwargs.len();
|
||||||
let r = self.get_ty(instantiated);
|
|
||||||
let r = r.as_ref();
|
// Now we check the arguments against the parameters
|
||||||
let TypeEnum::TFunc(signature) = r else { unreachable!() };
|
|
||||||
// we check to make sure that all required arguments (those without default
|
// Helper lambdas
|
||||||
// arguments) are provided, and do not provide the same argument twice.
|
let mut type_check_arg = |param_name, expected_arg_ty, arg_ty| {
|
||||||
let mut required = required.to_vec();
|
let ok = self.unify_impl(expected_arg_ty, arg_ty, false).is_ok();
|
||||||
let mut all_names: Vec<_> = signature.args.iter().map(|v| (v.name, v.ty)).rev().collect();
|
if ok {
|
||||||
for (i, t) in posargs.iter().enumerate() {
|
Ok(())
|
||||||
if signature.args.len() <= i {
|
} else {
|
||||||
|
// Typecheck failed, throw an error.
|
||||||
self.restore_snapshot();
|
self.restore_snapshot();
|
||||||
return Err(TypeError::new(
|
Err(TypeError::new(
|
||||||
TypeErrorKind::TooManyArguments {
|
TypeErrorKind::IncorrectArgType {
|
||||||
expected: signature.args.len(),
|
name: param_name,
|
||||||
got: posargs.len() + kwargs.len(),
|
expected: expected_arg_ty,
|
||||||
|
got: arg_ty,
|
||||||
},
|
},
|
||||||
*loc,
|
*loc,
|
||||||
));
|
))
|
||||||
}
|
}
|
||||||
required.pop();
|
};
|
||||||
let (name, expected) = all_names.pop().unwrap();
|
|
||||||
self.unify_impl(expected, *t, false).map_err(|_| {
|
// Check for "too many arguments"
|
||||||
self.restore_snapshot();
|
if num_params < posargs.len() {
|
||||||
TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc)
|
let expected_min_count =
|
||||||
})?;
|
signature.args.iter().filter(|param| param.is_required()).count();
|
||||||
}
|
let expected_max_count = num_params;
|
||||||
for (k, t) in kwargs {
|
|
||||||
if let Some(i) = required.iter().position(|v| v == k) {
|
|
||||||
required.remove(i);
|
|
||||||
}
|
|
||||||
let i = all_names.iter().position(|v| &v.0 == k).ok_or_else(|| {
|
|
||||||
self.restore_snapshot();
|
|
||||||
TypeError::new(TypeErrorKind::UnknownArgName(*k), *loc)
|
|
||||||
})?;
|
|
||||||
let (name, expected) = all_names.remove(i);
|
|
||||||
self.unify_impl(expected, *t, false).map_err(|_| {
|
|
||||||
self.restore_snapshot();
|
|
||||||
TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc)
|
|
||||||
})?;
|
|
||||||
}
|
|
||||||
if !required.is_empty() {
|
|
||||||
self.restore_snapshot();
|
self.restore_snapshot();
|
||||||
return Err(TypeError::new(
|
return Err(TypeError::new(
|
||||||
TypeErrorKind::MissingArgs(required.iter().join(", ")),
|
TypeErrorKind::TooManyArguments {
|
||||||
|
expected_min_count,
|
||||||
|
expected_max_count,
|
||||||
|
got_count: num_args,
|
||||||
|
},
|
||||||
*loc,
|
*loc,
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NOTE: order of `param_info_by_name` is leveraged, so use an IndexMap
|
||||||
|
let mut param_info_by_name: IndexMap<StrRef, ParamInfo> = signature
|
||||||
|
.args
|
||||||
|
.iter()
|
||||||
|
.map(|arg| (arg.name, ParamInfo { has_been_supplied: false, param: arg }))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Now consume all positional arguments and typecheck them.
|
||||||
|
for (&arg_ty, param) in zip(posargs, signature.args.iter()) {
|
||||||
|
// We will also use this opportunity to mark the corresponding `param_info` as having been supplied.
|
||||||
|
let param_info = param_info_by_name.get_mut(¶m.name).unwrap();
|
||||||
|
param_info.has_been_supplied = true;
|
||||||
|
|
||||||
|
// Typecheck
|
||||||
|
type_check_arg(param.name, param.ty, arg_ty)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now consume all keyword arguments and typecheck them.
|
||||||
|
for (¶m_name, &arg_ty) in kwargs {
|
||||||
|
// We will also use this opportunity to check if this keyword argument is "legal".
|
||||||
|
|
||||||
|
let Some(param_info) = param_info_by_name.get_mut(¶m_name) else {
|
||||||
|
self.restore_snapshot();
|
||||||
|
return Err(TypeError::new(TypeErrorKind::UnknownArgName(param_name), *loc));
|
||||||
|
};
|
||||||
|
|
||||||
|
if param_info.has_been_supplied {
|
||||||
|
// NOTE: Duplicate keyword argument (i.e., `hello(1, 2, 3, arg = 4, arg = 5)`)
|
||||||
|
// is IMPOSSIBLE as the parser would have already failed.
|
||||||
|
// We only have to care about "got multiple values for XYZ"
|
||||||
|
|
||||||
|
self.restore_snapshot();
|
||||||
|
return Err(TypeError::new(
|
||||||
|
TypeErrorKind::GotMultipleValues { name: param_name },
|
||||||
|
*loc,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
param_info.has_been_supplied = true;
|
||||||
|
|
||||||
|
// Typecheck
|
||||||
|
type_check_arg(param_name, param_info.param.ty, arg_ty)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// After checking posargs and kwargs, check if there are any
|
||||||
|
// unsupported required parameters, and throw an error if they exist.
|
||||||
|
let missing_arg_names = param_info_by_name
|
||||||
|
.values()
|
||||||
|
.filter(|param_info| param_info.param.is_required() && !param_info.has_been_supplied)
|
||||||
|
.map(|param_info| param_info.param.name)
|
||||||
|
.collect_vec();
|
||||||
|
if !missing_arg_names.is_empty() {
|
||||||
|
self.restore_snapshot();
|
||||||
|
return Err(TypeError::new(TypeErrorKind::MissingArgs { missing_arg_names }, *loc));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finally, check the Call's return type
|
||||||
self.unify_impl(*ret, signature.ret, false).map_err(|mut err| {
|
self.unify_impl(*ret, signature.ret, false).map_err(|mut err| {
|
||||||
self.restore_snapshot();
|
self.restore_snapshot();
|
||||||
if err.loc.is_none() {
|
if err.loc.is_none() {
|
||||||
@ -624,7 +723,8 @@ impl Unifier {
|
|||||||
}
|
}
|
||||||
err
|
err
|
||||||
})?;
|
})?;
|
||||||
*fun.borrow_mut() = Some(instantiated);
|
|
||||||
|
*fun.borrow_mut() = Some(b);
|
||||||
|
|
||||||
self.discard_snapshot(snapshot);
|
self.discard_snapshot(snapshot);
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -990,17 +1090,10 @@ impl Unifier {
|
|||||||
self.unification_table.set_value(b, Rc::new(TCall(calls)));
|
self.unification_table.set_value(b, Rc::new(TCall(calls)));
|
||||||
}
|
}
|
||||||
(TCall(calls), TFunc(signature)) => {
|
(TCall(calls), TFunc(signature)) => {
|
||||||
let required: Vec<StrRef> = signature
|
|
||||||
.args
|
|
||||||
.iter()
|
|
||||||
.filter(|v| v.default_value.is_none())
|
|
||||||
.map(|v| v.name)
|
|
||||||
.rev()
|
|
||||||
.collect();
|
|
||||||
// we unify every calls to the function signature.
|
// we unify every calls to the function signature.
|
||||||
for c in calls {
|
for c in calls {
|
||||||
let call = self.calls[c.0].clone();
|
let call = self.calls[c.0].clone();
|
||||||
self.unify_call(&call, b, signature, &required)?;
|
self.unify_call(&call, b, signature)?;
|
||||||
}
|
}
|
||||||
self.set_a_to_b(a, b);
|
self.set_a_to_b(a, b);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user