diff --git a/nac3core/src/typecheck/type_error.rs b/nac3core/src/typecheck/type_error.rs index e6d7d73..d03b3f1 100644 --- a/nac3core/src/typecheck/type_error.rs +++ b/nac3core/src/typecheck/type_error.rs @@ -4,15 +4,22 @@ use std::fmt::Display; use crate::typecheck::typedef::TypeEnum; use super::typedef::{RecordKey, Type, Unifier}; +use itertools::Itertools; use nac3parser::ast::{Location, StrRef}; #[derive(Debug, Clone)] pub enum TypeErrorKind { - TooManyArguments { - expected: usize, - got: usize, + GotMultipleValues { + name: StrRef, + }, + TooManyArguments { + expected_min_count: usize, + expected_max_count: usize, + got_count: usize, + }, + MissingArgs { + missing_arg_names: Vec, }, - MissingArgs(String), UnknownArgName(StrRef), IncorrectArgType { name: StrRef, @@ -78,10 +85,20 @@ impl<'a> Display for DisplayTypeError<'a> { use TypeErrorKind::*; let mut notes = Some(HashMap::new()); match &self.err.kind { - TooManyArguments { expected, got } => { - write!(f, "Too many arguments. Expected {expected} but got {got}") + GotMultipleValues { name } => { + write!(f, "For multiple values for parameter {name}") } - MissingArgs(args) => { + TooManyArguments { expected_min_count, expected_max_count, got_count } => { + debug_assert!(expected_min_count <= expected_max_count); + if expected_min_count == expected_max_count { + let expected_count = expected_min_count; // or expected_max_count + write!(f, "Too many arguments. Expected {expected_count} but got {got_count}") + } else { + write!(f, "Too many arguments. Expected {expected_min_count} to {expected_max_count} arguments but got {got_count}") + } + } + MissingArgs { missing_arg_names } => { + let args = missing_arg_names.iter().join(", "); write!(f, "Missing arguments: {args}") } UnknownArgName(name) => { @@ -90,7 +107,7 @@ impl<'a> Display for DisplayTypeError<'a> { IncorrectArgType { name, expected, got } => { let expected = self.unifier.stringify_with_notes(*expected, &mut notes); let got = self.unifier.stringify_with_notes(*got, &mut notes); - write!(f, "Incorrect argument type for {name}. Expected {expected}, but got {got}") + write!(f, "Incorrect argument type for parameter {name}. Expected {expected}, but got {got}") } FieldUnificationError { field, types, loc } => { let lhs = self.unifier.stringify_with_notes(types.0, &mut notes); diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index a3366c9..2e14d15 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -642,14 +642,7 @@ impl<'a> Inferencer<'a> { }) .unwrap(); } - let required: Vec<_> = sign - .args - .iter() - .filter(|v| v.default_value.is_none()) - .map(|v| v.name) - .rev() - .collect(); - self.unifier.unify_call(&call, ty, sign, &required).map_err(|e| { + self.unifier.unify_call(&call, ty, sign).map_err(|e| { HashSet::from([e .at(Some(location)) .to_display(self.unifier) @@ -1347,16 +1340,9 @@ impl<'a> Inferencer<'a> { ret: sign.ret, loc: Some(location), }; - let required: Vec<_> = sign - .args - .iter() - .filter(|v| v.default_value.is_none()) - .map(|v| v.name) - .rev() - .collect(); - self.unifier.unify_call(&call, func.custom.unwrap(), sign, &required).map_err( - |e| HashSet::from([e.at(Some(location)).to_display(self.unifier).to_string()]), - )?; + self.unifier.unify_call(&call, func.custom.unwrap(), sign).map_err(|e| { + HashSet::from([e.at(Some(location)).to_display(self.unifier).to_string()]) + })?; return Ok(Located { location, custom: Some(sign.ret), diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index fc6952f..f041679 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -89,6 +89,13 @@ pub struct FuncArg { pub default_value: Option, } +impl FuncArg { + #[must_use] + pub fn is_required(&self) -> bool { + self.default_value.is_none() + } +} + #[derive(Debug, Clone)] pub struct FunSignature { pub args: Vec, @@ -562,61 +569,153 @@ impl Unifier { call: &Call, b: Type, signature: &FunSignature, - required: &[StrRef], ) -> Result<(), TypeError> { + /* + NOTE: scenarios to consider: + + ```python + def func1(x: int32, y: int32, z: int32 = 5): pass + + # Normal scenarios + func1(23, 45) # OK, z has default + func1(23, 45, 67) # OK, z's default is overwritten + func1(x = 23, y = 45) # OK, user is using kwargs to set positional args + func1(y = 45, x = 23) # OK, kwargs order doesn't matter + + # Error scenarios + func1() # ERROR: Missing arguments: x, y + func1(23) # ERROR: Missing arguments: y + func1(z = 23) # ERROR: Missing arguments: x, y + func1(x = 23) # ERROR: Missing arguments: y + func1(23, 45, x = 5) # ERROR: Got multiple values for x + func1(23, 45, x = 5, y = 6) # ERROR: Got multiple values for x (y too but Python does not report it) + func1(23, 45, 67, z = 89) # ERROR: Got multiple values for z + func1(23, 45, 67, 89) # ERROR: Function only takes from 2 to 3 positional arguments but 4 were given. + func1(23, 45, 67, w = 3) # ERROR: Got an unexpected keyword argument 'w' + + # Error scenarios that do not need to be handled here. + func1(23, 45, z = 67, z = 89) # ERROR: Keyword argument repeated: z, the parser panics on this. + ``` + */ + + struct ParamInfo<'a> { + /// Has this parameter been supplied with an argument already? + has_been_supplied: bool, + /// The corresponding [`FuncArg`] instance of this parameter (for fast table lookups) + param: &'a FuncArg, + } + let snapshot = self.unification_table.get_snapshot(); if self.snapshot.is_none() { self.snapshot = Some(snapshot); } + // Get details about the function signature/parameters. + let num_params = signature.args.len(); + + // Force the type vars in `b` and `signature' to be up-to-date. + let b = self.instantiate_fun(b, signature); + let TypeEnum::TFunc(signature) = &*self.get_ty(b) else { unreachable!() }; + + // Get details about the input arguments let Call { posargs, kwargs, ret, fun, loc } = call; - let instantiated = self.instantiate_fun(b, signature); - let r = self.get_ty(instantiated); - let r = r.as_ref(); - let TypeEnum::TFunc(signature) = r else { unreachable!() }; - // we check to make sure that all required arguments (those without default - // arguments) are provided, and do not provide the same argument twice. - let mut required = required.to_vec(); - let mut all_names: Vec<_> = signature.args.iter().map(|v| (v.name, v.ty)).rev().collect(); - for (i, t) in posargs.iter().enumerate() { - if signature.args.len() <= i { + let num_args = posargs.len() + kwargs.len(); + + // Now we check the arguments against the parameters + + // Helper lambdas + let mut type_check_arg = |param_name, expected_arg_ty, arg_ty| { + let ok = self.unify_impl(expected_arg_ty, arg_ty, false).is_ok(); + if ok { + Ok(()) + } else { + // Typecheck failed, throw an error. self.restore_snapshot(); - return Err(TypeError::new( - TypeErrorKind::TooManyArguments { - expected: signature.args.len(), - got: posargs.len() + kwargs.len(), + Err(TypeError::new( + TypeErrorKind::IncorrectArgType { + name: param_name, + expected: expected_arg_ty, + got: arg_ty, }, *loc, - )); + )) } - required.pop(); - let (name, expected) = all_names.pop().unwrap(); - self.unify_impl(expected, *t, false).map_err(|_| { - self.restore_snapshot(); - TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc) - })?; - } - for (k, t) in kwargs { - if let Some(i) = required.iter().position(|v| v == k) { - required.remove(i); - } - let i = all_names.iter().position(|v| &v.0 == k).ok_or_else(|| { - self.restore_snapshot(); - TypeError::new(TypeErrorKind::UnknownArgName(*k), *loc) - })?; - let (name, expected) = all_names.remove(i); - self.unify_impl(expected, *t, false).map_err(|_| { - self.restore_snapshot(); - TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc) - })?; - } - if !required.is_empty() { + }; + + // Check for "too many arguments" + if num_params < posargs.len() { + let expected_min_count = + signature.args.iter().filter(|param| param.is_required()).count(); + let expected_max_count = num_params; + self.restore_snapshot(); return Err(TypeError::new( - TypeErrorKind::MissingArgs(required.iter().join(", ")), + TypeErrorKind::TooManyArguments { + expected_min_count, + expected_max_count, + got_count: num_args, + }, *loc, )); } + + // NOTE: order of `param_info_by_name` is leveraged, so use an IndexMap + let mut param_info_by_name: IndexMap = signature + .args + .iter() + .map(|arg| (arg.name, ParamInfo { has_been_supplied: false, param: arg })) + .collect(); + + // Now consume all positional arguments and typecheck them. + for (&arg_ty, param) in zip(posargs, signature.args.iter()) { + // We will also use this opportunity to mark the corresponding `param_info` as having been supplied. + let param_info = param_info_by_name.get_mut(¶m.name).unwrap(); + param_info.has_been_supplied = true; + + // Typecheck + type_check_arg(param.name, param.ty, arg_ty)?; + } + + // Now consume all keyword arguments and typecheck them. + for (¶m_name, &arg_ty) in kwargs { + // We will also use this opportunity to check if this keyword argument is "legal". + + let Some(param_info) = param_info_by_name.get_mut(¶m_name) else { + self.restore_snapshot(); + return Err(TypeError::new(TypeErrorKind::UnknownArgName(param_name), *loc)); + }; + + if param_info.has_been_supplied { + // NOTE: Duplicate keyword argument (i.e., `hello(1, 2, 3, arg = 4, arg = 5)`) + // is IMPOSSIBLE as the parser would have already failed. + // We only have to care about "got multiple values for XYZ" + + self.restore_snapshot(); + return Err(TypeError::new( + TypeErrorKind::GotMultipleValues { name: param_name }, + *loc, + )); + } + + param_info.has_been_supplied = true; + + // Typecheck + type_check_arg(param_name, param_info.param.ty, arg_ty)?; + } + + // After checking posargs and kwargs, check if there are any + // unsupported required parameters, and throw an error if they exist. + let missing_arg_names = param_info_by_name + .values() + .filter(|param_info| param_info.param.is_required() && !param_info.has_been_supplied) + .map(|param_info| param_info.param.name) + .collect_vec(); + if !missing_arg_names.is_empty() { + self.restore_snapshot(); + return Err(TypeError::new(TypeErrorKind::MissingArgs { missing_arg_names }, *loc)); + } + + // Finally, check the Call's return type self.unify_impl(*ret, signature.ret, false).map_err(|mut err| { self.restore_snapshot(); if err.loc.is_none() { @@ -624,7 +723,8 @@ impl Unifier { } err })?; - *fun.borrow_mut() = Some(instantiated); + + *fun.borrow_mut() = Some(b); self.discard_snapshot(snapshot); Ok(()) @@ -990,17 +1090,10 @@ impl Unifier { self.unification_table.set_value(b, Rc::new(TCall(calls))); } (TCall(calls), TFunc(signature)) => { - let required: Vec = signature - .args - .iter() - .filter(|v| v.default_value.is_none()) - .map(|v| v.name) - .rev() - .collect(); // we unify every calls to the function signature. for c in calls { let call = self.calls[c.0].clone(); - self.unify_call(&call, b, signature, &required)?; + self.unify_call(&call, b, signature)?; } self.set_a_to_b(a, b); }