core: improve function call errors

This commit is contained in:
lyken 2024-06-26 15:55:43 +08:00 committed by sb10q
parent ca8459dc7b
commit 0ec967a468
3 changed files with 170 additions and 74 deletions

View File

@ -4,15 +4,22 @@ use std::fmt::Display;
use crate::typecheck::typedef::TypeEnum;
use super::typedef::{RecordKey, Type, Unifier};
use itertools::Itertools;
use nac3parser::ast::{Location, StrRef};
#[derive(Debug, Clone)]
pub enum TypeErrorKind {
TooManyArguments {
expected: usize,
got: usize,
GotMultipleValues {
name: StrRef,
},
TooManyArguments {
expected_min_count: usize,
expected_max_count: usize,
got_count: usize,
},
MissingArgs {
missing_arg_names: Vec<StrRef>,
},
MissingArgs(String),
UnknownArgName(StrRef),
IncorrectArgType {
name: StrRef,
@ -78,10 +85,20 @@ impl<'a> Display for DisplayTypeError<'a> {
use TypeErrorKind::*;
let mut notes = Some(HashMap::new());
match &self.err.kind {
TooManyArguments { expected, got } => {
write!(f, "Too many arguments. Expected {expected} but got {got}")
GotMultipleValues { name } => {
write!(f, "For multiple values for parameter {name}")
}
MissingArgs(args) => {
TooManyArguments { expected_min_count, expected_max_count, got_count } => {
debug_assert!(expected_min_count <= expected_max_count);
if expected_min_count == expected_max_count {
let expected_count = expected_min_count; // or expected_max_count
write!(f, "Too many arguments. Expected {expected_count} but got {got_count}")
} else {
write!(f, "Too many arguments. Expected {expected_min_count} to {expected_max_count} arguments but got {got_count}")
}
}
MissingArgs { missing_arg_names } => {
let args = missing_arg_names.iter().join(", ");
write!(f, "Missing arguments: {args}")
}
UnknownArgName(name) => {
@ -90,7 +107,7 @@ impl<'a> Display for DisplayTypeError<'a> {
IncorrectArgType { name, expected, got } => {
let expected = self.unifier.stringify_with_notes(*expected, &mut notes);
let got = self.unifier.stringify_with_notes(*got, &mut notes);
write!(f, "Incorrect argument type for {name}. Expected {expected}, but got {got}")
write!(f, "Incorrect argument type for parameter {name}. Expected {expected}, but got {got}")
}
FieldUnificationError { field, types, loc } => {
let lhs = self.unifier.stringify_with_notes(types.0, &mut notes);

View File

@ -642,14 +642,7 @@ impl<'a> Inferencer<'a> {
})
.unwrap();
}
let required: Vec<_> = sign
.args
.iter()
.filter(|v| v.default_value.is_none())
.map(|v| v.name)
.rev()
.collect();
self.unifier.unify_call(&call, ty, sign, &required).map_err(|e| {
self.unifier.unify_call(&call, ty, sign).map_err(|e| {
HashSet::from([e
.at(Some(location))
.to_display(self.unifier)
@ -1347,16 +1340,9 @@ impl<'a> Inferencer<'a> {
ret: sign.ret,
loc: Some(location),
};
let required: Vec<_> = sign
.args
.iter()
.filter(|v| v.default_value.is_none())
.map(|v| v.name)
.rev()
.collect();
self.unifier.unify_call(&call, func.custom.unwrap(), sign, &required).map_err(
|e| HashSet::from([e.at(Some(location)).to_display(self.unifier).to_string()]),
)?;
self.unifier.unify_call(&call, func.custom.unwrap(), sign).map_err(|e| {
HashSet::from([e.at(Some(location)).to_display(self.unifier).to_string()])
})?;
return Ok(Located {
location,
custom: Some(sign.ret),

View File

@ -89,6 +89,13 @@ pub struct FuncArg {
pub default_value: Option<SymbolValue>,
}
impl FuncArg {
#[must_use]
pub fn is_required(&self) -> bool {
self.default_value.is_none()
}
}
#[derive(Debug, Clone)]
pub struct FunSignature {
pub args: Vec<FuncArg>,
@ -562,61 +569,153 @@ impl Unifier {
call: &Call,
b: Type,
signature: &FunSignature,
required: &[StrRef],
) -> Result<(), TypeError> {
/*
NOTE: scenarios to consider:
```python
def func1(x: int32, y: int32, z: int32 = 5): pass
# Normal scenarios
func1(23, 45) # OK, z has default
func1(23, 45, 67) # OK, z's default is overwritten
func1(x = 23, y = 45) # OK, user is using kwargs to set positional args
func1(y = 45, x = 23) # OK, kwargs order doesn't matter
# Error scenarios
func1() # ERROR: Missing arguments: x, y
func1(23) # ERROR: Missing arguments: y
func1(z = 23) # ERROR: Missing arguments: x, y
func1(x = 23) # ERROR: Missing arguments: y
func1(23, 45, x = 5) # ERROR: Got multiple values for x
func1(23, 45, x = 5, y = 6) # ERROR: Got multiple values for x (y too but Python does not report it)
func1(23, 45, 67, z = 89) # ERROR: Got multiple values for z
func1(23, 45, 67, 89) # ERROR: Function only takes from 2 to 3 positional arguments but 4 were given.
func1(23, 45, 67, w = 3) # ERROR: Got an unexpected keyword argument 'w'
# Error scenarios that do not need to be handled here.
func1(23, 45, z = 67, z = 89) # ERROR: Keyword argument repeated: z, the parser panics on this.
```
*/
struct ParamInfo<'a> {
/// Has this parameter been supplied with an argument already?
has_been_supplied: bool,
/// The corresponding [`FuncArg`] instance of this parameter (for fast table lookups)
param: &'a FuncArg,
}
let snapshot = self.unification_table.get_snapshot();
if self.snapshot.is_none() {
self.snapshot = Some(snapshot);
}
// Get details about the function signature/parameters.
let num_params = signature.args.len();
// Force the type vars in `b` and `signature' to be up-to-date.
let b = self.instantiate_fun(b, signature);
let TypeEnum::TFunc(signature) = &*self.get_ty(b) else { unreachable!() };
// Get details about the input arguments
let Call { posargs, kwargs, ret, fun, loc } = call;
let instantiated = self.instantiate_fun(b, signature);
let r = self.get_ty(instantiated);
let r = r.as_ref();
let TypeEnum::TFunc(signature) = r else { unreachable!() };
// we check to make sure that all required arguments (those without default
// arguments) are provided, and do not provide the same argument twice.
let mut required = required.to_vec();
let mut all_names: Vec<_> = signature.args.iter().map(|v| (v.name, v.ty)).rev().collect();
for (i, t) in posargs.iter().enumerate() {
if signature.args.len() <= i {
let num_args = posargs.len() + kwargs.len();
// Now we check the arguments against the parameters
// Helper lambdas
let mut type_check_arg = |param_name, expected_arg_ty, arg_ty| {
let ok = self.unify_impl(expected_arg_ty, arg_ty, false).is_ok();
if ok {
Ok(())
} else {
// Typecheck failed, throw an error.
self.restore_snapshot();
Err(TypeError::new(
TypeErrorKind::IncorrectArgType {
name: param_name,
expected: expected_arg_ty,
got: arg_ty,
},
*loc,
))
}
};
// Check for "too many arguments"
if num_params < posargs.len() {
let expected_min_count =
signature.args.iter().filter(|param| param.is_required()).count();
let expected_max_count = num_params;
self.restore_snapshot();
return Err(TypeError::new(
TypeErrorKind::TooManyArguments {
expected: signature.args.len(),
got: posargs.len() + kwargs.len(),
expected_min_count,
expected_max_count,
got_count: num_args,
},
*loc,
));
}
required.pop();
let (name, expected) = all_names.pop().unwrap();
self.unify_impl(expected, *t, false).map_err(|_| {
self.restore_snapshot();
TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc)
})?;
// NOTE: order of `param_info_by_name` is leveraged, so use an IndexMap
let mut param_info_by_name: IndexMap<StrRef, ParamInfo> = signature
.args
.iter()
.map(|arg| (arg.name, ParamInfo { has_been_supplied: false, param: arg }))
.collect();
// Now consume all positional arguments and typecheck them.
for (&arg_ty, param) in zip(posargs, signature.args.iter()) {
// We will also use this opportunity to mark the corresponding `param_info` as having been supplied.
let param_info = param_info_by_name.get_mut(&param.name).unwrap();
param_info.has_been_supplied = true;
// Typecheck
type_check_arg(param.name, param.ty, arg_ty)?;
}
for (k, t) in kwargs {
if let Some(i) = required.iter().position(|v| v == k) {
required.remove(i);
}
let i = all_names.iter().position(|v| &v.0 == k).ok_or_else(|| {
// Now consume all keyword arguments and typecheck them.
for (&param_name, &arg_ty) in kwargs {
// We will also use this opportunity to check if this keyword argument is "legal".
let Some(param_info) = param_info_by_name.get_mut(&param_name) else {
self.restore_snapshot();
TypeError::new(TypeErrorKind::UnknownArgName(*k), *loc)
})?;
let (name, expected) = all_names.remove(i);
self.unify_impl(expected, *t, false).map_err(|_| {
self.restore_snapshot();
TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc)
})?;
}
if !required.is_empty() {
return Err(TypeError::new(TypeErrorKind::UnknownArgName(param_name), *loc));
};
if param_info.has_been_supplied {
// NOTE: Duplicate keyword argument (i.e., `hello(1, 2, 3, arg = 4, arg = 5)`)
// is IMPOSSIBLE as the parser would have already failed.
// We only have to care about "got multiple values for XYZ"
self.restore_snapshot();
return Err(TypeError::new(
TypeErrorKind::MissingArgs(required.iter().join(", ")),
TypeErrorKind::GotMultipleValues { name: param_name },
*loc,
));
}
param_info.has_been_supplied = true;
// Typecheck
type_check_arg(param_name, param_info.param.ty, arg_ty)?;
}
// After checking posargs and kwargs, check if there are any
// unsupported required parameters, and throw an error if they exist.
let missing_arg_names = param_info_by_name
.values()
.filter(|param_info| param_info.param.is_required() && !param_info.has_been_supplied)
.map(|param_info| param_info.param.name)
.collect_vec();
if !missing_arg_names.is_empty() {
self.restore_snapshot();
return Err(TypeError::new(TypeErrorKind::MissingArgs { missing_arg_names }, *loc));
}
// Finally, check the Call's return type
self.unify_impl(*ret, signature.ret, false).map_err(|mut err| {
self.restore_snapshot();
if err.loc.is_none() {
@ -624,7 +723,8 @@ impl Unifier {
}
err
})?;
*fun.borrow_mut() = Some(instantiated);
*fun.borrow_mut() = Some(b);
self.discard_snapshot(snapshot);
Ok(())
@ -990,17 +1090,10 @@ impl Unifier {
self.unification_table.set_value(b, Rc::new(TCall(calls)));
}
(TCall(calls), TFunc(signature)) => {
let required: Vec<StrRef> = signature
.args
.iter()
.filter(|v| v.default_value.is_none())
.map(|v| v.name)
.rev()
.collect();
// we unify every calls to the function signature.
for c in calls {
let call = self.calls[c.0].clone();
self.unify_call(&call, b, signature, &required)?;
self.unify_call(&call, b, signature)?;
}
self.set_a_to_b(a, b);
}