forked from M-Labs/nac3
1
0
Fork 0

core: improve function call errors

This commit is contained in:
lyken 2024-06-26 15:55:43 +08:00 committed by sb10q
parent ca8459dc7b
commit 0ec967a468
3 changed files with 170 additions and 74 deletions

View File

@ -4,15 +4,22 @@ use std::fmt::Display;
use crate::typecheck::typedef::TypeEnum; use crate::typecheck::typedef::TypeEnum;
use super::typedef::{RecordKey, Type, Unifier}; use super::typedef::{RecordKey, Type, Unifier};
use itertools::Itertools;
use nac3parser::ast::{Location, StrRef}; use nac3parser::ast::{Location, StrRef};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum TypeErrorKind { pub enum TypeErrorKind {
TooManyArguments { GotMultipleValues {
expected: usize, name: StrRef,
got: usize, },
TooManyArguments {
expected_min_count: usize,
expected_max_count: usize,
got_count: usize,
},
MissingArgs {
missing_arg_names: Vec<StrRef>,
}, },
MissingArgs(String),
UnknownArgName(StrRef), UnknownArgName(StrRef),
IncorrectArgType { IncorrectArgType {
name: StrRef, name: StrRef,
@ -78,10 +85,20 @@ impl<'a> Display for DisplayTypeError<'a> {
use TypeErrorKind::*; use TypeErrorKind::*;
let mut notes = Some(HashMap::new()); let mut notes = Some(HashMap::new());
match &self.err.kind { match &self.err.kind {
TooManyArguments { expected, got } => { GotMultipleValues { name } => {
write!(f, "Too many arguments. Expected {expected} but got {got}") write!(f, "For multiple values for parameter {name}")
} }
MissingArgs(args) => { TooManyArguments { expected_min_count, expected_max_count, got_count } => {
debug_assert!(expected_min_count <= expected_max_count);
if expected_min_count == expected_max_count {
let expected_count = expected_min_count; // or expected_max_count
write!(f, "Too many arguments. Expected {expected_count} but got {got_count}")
} else {
write!(f, "Too many arguments. Expected {expected_min_count} to {expected_max_count} arguments but got {got_count}")
}
}
MissingArgs { missing_arg_names } => {
let args = missing_arg_names.iter().join(", ");
write!(f, "Missing arguments: {args}") write!(f, "Missing arguments: {args}")
} }
UnknownArgName(name) => { UnknownArgName(name) => {
@ -90,7 +107,7 @@ impl<'a> Display for DisplayTypeError<'a> {
IncorrectArgType { name, expected, got } => { IncorrectArgType { name, expected, got } => {
let expected = self.unifier.stringify_with_notes(*expected, &mut notes); let expected = self.unifier.stringify_with_notes(*expected, &mut notes);
let got = self.unifier.stringify_with_notes(*got, &mut notes); let got = self.unifier.stringify_with_notes(*got, &mut notes);
write!(f, "Incorrect argument type for {name}. Expected {expected}, but got {got}") write!(f, "Incorrect argument type for parameter {name}. Expected {expected}, but got {got}")
} }
FieldUnificationError { field, types, loc } => { FieldUnificationError { field, types, loc } => {
let lhs = self.unifier.stringify_with_notes(types.0, &mut notes); let lhs = self.unifier.stringify_with_notes(types.0, &mut notes);

View File

@ -642,14 +642,7 @@ impl<'a> Inferencer<'a> {
}) })
.unwrap(); .unwrap();
} }
let required: Vec<_> = sign self.unifier.unify_call(&call, ty, sign).map_err(|e| {
.args
.iter()
.filter(|v| v.default_value.is_none())
.map(|v| v.name)
.rev()
.collect();
self.unifier.unify_call(&call, ty, sign, &required).map_err(|e| {
HashSet::from([e HashSet::from([e
.at(Some(location)) .at(Some(location))
.to_display(self.unifier) .to_display(self.unifier)
@ -1347,16 +1340,9 @@ impl<'a> Inferencer<'a> {
ret: sign.ret, ret: sign.ret,
loc: Some(location), loc: Some(location),
}; };
let required: Vec<_> = sign self.unifier.unify_call(&call, func.custom.unwrap(), sign).map_err(|e| {
.args HashSet::from([e.at(Some(location)).to_display(self.unifier).to_string()])
.iter() })?;
.filter(|v| v.default_value.is_none())
.map(|v| v.name)
.rev()
.collect();
self.unifier.unify_call(&call, func.custom.unwrap(), sign, &required).map_err(
|e| HashSet::from([e.at(Some(location)).to_display(self.unifier).to_string()]),
)?;
return Ok(Located { return Ok(Located {
location, location,
custom: Some(sign.ret), custom: Some(sign.ret),

View File

@ -89,6 +89,13 @@ pub struct FuncArg {
pub default_value: Option<SymbolValue>, pub default_value: Option<SymbolValue>,
} }
impl FuncArg {
#[must_use]
pub fn is_required(&self) -> bool {
self.default_value.is_none()
}
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct FunSignature { pub struct FunSignature {
pub args: Vec<FuncArg>, pub args: Vec<FuncArg>,
@ -562,61 +569,153 @@ impl Unifier {
call: &Call, call: &Call,
b: Type, b: Type,
signature: &FunSignature, signature: &FunSignature,
required: &[StrRef],
) -> Result<(), TypeError> { ) -> Result<(), TypeError> {
/*
NOTE: scenarios to consider:
```python
def func1(x: int32, y: int32, z: int32 = 5): pass
# Normal scenarios
func1(23, 45) # OK, z has default
func1(23, 45, 67) # OK, z's default is overwritten
func1(x = 23, y = 45) # OK, user is using kwargs to set positional args
func1(y = 45, x = 23) # OK, kwargs order doesn't matter
# Error scenarios
func1() # ERROR: Missing arguments: x, y
func1(23) # ERROR: Missing arguments: y
func1(z = 23) # ERROR: Missing arguments: x, y
func1(x = 23) # ERROR: Missing arguments: y
func1(23, 45, x = 5) # ERROR: Got multiple values for x
func1(23, 45, x = 5, y = 6) # ERROR: Got multiple values for x (y too but Python does not report it)
func1(23, 45, 67, z = 89) # ERROR: Got multiple values for z
func1(23, 45, 67, 89) # ERROR: Function only takes from 2 to 3 positional arguments but 4 were given.
func1(23, 45, 67, w = 3) # ERROR: Got an unexpected keyword argument 'w'
# Error scenarios that do not need to be handled here.
func1(23, 45, z = 67, z = 89) # ERROR: Keyword argument repeated: z, the parser panics on this.
```
*/
struct ParamInfo<'a> {
/// Has this parameter been supplied with an argument already?
has_been_supplied: bool,
/// The corresponding [`FuncArg`] instance of this parameter (for fast table lookups)
param: &'a FuncArg,
}
let snapshot = self.unification_table.get_snapshot(); let snapshot = self.unification_table.get_snapshot();
if self.snapshot.is_none() { if self.snapshot.is_none() {
self.snapshot = Some(snapshot); self.snapshot = Some(snapshot);
} }
// Get details about the function signature/parameters.
let num_params = signature.args.len();
// Force the type vars in `b` and `signature' to be up-to-date.
let b = self.instantiate_fun(b, signature);
let TypeEnum::TFunc(signature) = &*self.get_ty(b) else { unreachable!() };
// Get details about the input arguments
let Call { posargs, kwargs, ret, fun, loc } = call; let Call { posargs, kwargs, ret, fun, loc } = call;
let instantiated = self.instantiate_fun(b, signature); let num_args = posargs.len() + kwargs.len();
let r = self.get_ty(instantiated);
let r = r.as_ref(); // Now we check the arguments against the parameters
let TypeEnum::TFunc(signature) = r else { unreachable!() };
// we check to make sure that all required arguments (those without default // Helper lambdas
// arguments) are provided, and do not provide the same argument twice. let mut type_check_arg = |param_name, expected_arg_ty, arg_ty| {
let mut required = required.to_vec(); let ok = self.unify_impl(expected_arg_ty, arg_ty, false).is_ok();
let mut all_names: Vec<_> = signature.args.iter().map(|v| (v.name, v.ty)).rev().collect(); if ok {
for (i, t) in posargs.iter().enumerate() { Ok(())
if signature.args.len() <= i { } else {
// Typecheck failed, throw an error.
self.restore_snapshot(); self.restore_snapshot();
return Err(TypeError::new( Err(TypeError::new(
TypeErrorKind::TooManyArguments { TypeErrorKind::IncorrectArgType {
expected: signature.args.len(), name: param_name,
got: posargs.len() + kwargs.len(), expected: expected_arg_ty,
got: arg_ty,
}, },
*loc, *loc,
)); ))
} }
required.pop(); };
let (name, expected) = all_names.pop().unwrap();
self.unify_impl(expected, *t, false).map_err(|_| { // Check for "too many arguments"
self.restore_snapshot(); if num_params < posargs.len() {
TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc) let expected_min_count =
})?; signature.args.iter().filter(|param| param.is_required()).count();
} let expected_max_count = num_params;
for (k, t) in kwargs {
if let Some(i) = required.iter().position(|v| v == k) {
required.remove(i);
}
let i = all_names.iter().position(|v| &v.0 == k).ok_or_else(|| {
self.restore_snapshot();
TypeError::new(TypeErrorKind::UnknownArgName(*k), *loc)
})?;
let (name, expected) = all_names.remove(i);
self.unify_impl(expected, *t, false).map_err(|_| {
self.restore_snapshot();
TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc)
})?;
}
if !required.is_empty() {
self.restore_snapshot(); self.restore_snapshot();
return Err(TypeError::new( return Err(TypeError::new(
TypeErrorKind::MissingArgs(required.iter().join(", ")), TypeErrorKind::TooManyArguments {
expected_min_count,
expected_max_count,
got_count: num_args,
},
*loc, *loc,
)); ));
} }
// NOTE: order of `param_info_by_name` is leveraged, so use an IndexMap
let mut param_info_by_name: IndexMap<StrRef, ParamInfo> = signature
.args
.iter()
.map(|arg| (arg.name, ParamInfo { has_been_supplied: false, param: arg }))
.collect();
// Now consume all positional arguments and typecheck them.
for (&arg_ty, param) in zip(posargs, signature.args.iter()) {
// We will also use this opportunity to mark the corresponding `param_info` as having been supplied.
let param_info = param_info_by_name.get_mut(&param.name).unwrap();
param_info.has_been_supplied = true;
// Typecheck
type_check_arg(param.name, param.ty, arg_ty)?;
}
// Now consume all keyword arguments and typecheck them.
for (&param_name, &arg_ty) in kwargs {
// We will also use this opportunity to check if this keyword argument is "legal".
let Some(param_info) = param_info_by_name.get_mut(&param_name) else {
self.restore_snapshot();
return Err(TypeError::new(TypeErrorKind::UnknownArgName(param_name), *loc));
};
if param_info.has_been_supplied {
// NOTE: Duplicate keyword argument (i.e., `hello(1, 2, 3, arg = 4, arg = 5)`)
// is IMPOSSIBLE as the parser would have already failed.
// We only have to care about "got multiple values for XYZ"
self.restore_snapshot();
return Err(TypeError::new(
TypeErrorKind::GotMultipleValues { name: param_name },
*loc,
));
}
param_info.has_been_supplied = true;
// Typecheck
type_check_arg(param_name, param_info.param.ty, arg_ty)?;
}
// After checking posargs and kwargs, check if there are any
// unsupported required parameters, and throw an error if they exist.
let missing_arg_names = param_info_by_name
.values()
.filter(|param_info| param_info.param.is_required() && !param_info.has_been_supplied)
.map(|param_info| param_info.param.name)
.collect_vec();
if !missing_arg_names.is_empty() {
self.restore_snapshot();
return Err(TypeError::new(TypeErrorKind::MissingArgs { missing_arg_names }, *loc));
}
// Finally, check the Call's return type
self.unify_impl(*ret, signature.ret, false).map_err(|mut err| { self.unify_impl(*ret, signature.ret, false).map_err(|mut err| {
self.restore_snapshot(); self.restore_snapshot();
if err.loc.is_none() { if err.loc.is_none() {
@ -624,7 +723,8 @@ impl Unifier {
} }
err err
})?; })?;
*fun.borrow_mut() = Some(instantiated);
*fun.borrow_mut() = Some(b);
self.discard_snapshot(snapshot); self.discard_snapshot(snapshot);
Ok(()) Ok(())
@ -990,17 +1090,10 @@ impl Unifier {
self.unification_table.set_value(b, Rc::new(TCall(calls))); self.unification_table.set_value(b, Rc::new(TCall(calls)));
} }
(TCall(calls), TFunc(signature)) => { (TCall(calls), TFunc(signature)) => {
let required: Vec<StrRef> = signature
.args
.iter()
.filter(|v| v.default_value.is_none())
.map(|v| v.name)
.rev()
.collect();
// we unify every calls to the function signature. // we unify every calls to the function signature.
for c in calls { for c in calls {
let call = self.calls[c.0].clone(); let call = self.calls[c.0].clone();
self.unify_call(&call, b, signature, &required)?; self.unify_call(&call, b, signature)?;
} }
self.set_a_to_b(a, b); self.set_a_to_b(a, b);
} }