diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 7d322d6..bc761d7 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -248,17 +248,25 @@ impl<'a> Inferencer<'a> { if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(*ty) { let sign = sign.borrow(); if sign.vars.is_empty() { - let call = self.unifier.add_call(Call { + let call = Call { posargs: params, kwargs: HashMap::new(), ret: sign.ret, fun: RefCell::new(None), - }); - let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into())); + }; if let Some(ret) = ret { self.unifier.unify(sign.ret, ret).unwrap(); } - self.constrain(call, *ty, &location)?; + let required: Vec<_> = sign + .args + .iter() + .filter(|v| v.default_value.is_none()) + .map(|v| v.name) + .rev() + .collect(); + self.unifier + .unify_call(&call, *ty, &sign, &required) + .map_err(|old| format!("{} at {}", old, location))?; return Ok(sign.ret); } } @@ -491,7 +499,7 @@ impl<'a> Inferencer<'a> { if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(func.custom.unwrap()) { let sign = sign.borrow(); if sign.vars.is_empty() { - let call = self.unifier.add_call(Call { + let call = Call { posargs: args.iter().map(|v| v.custom.unwrap()).collect(), kwargs: keywords .iter() @@ -499,9 +507,17 @@ impl<'a> Inferencer<'a> { .collect(), fun: RefCell::new(None), ret: sign.ret, - }); - let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into())); - self.unify(func.custom.unwrap(), call, &func.location)?; + }; + let required: Vec<_> = sign + .args + .iter() + .filter(|v| v.default_value.is_none()) + .map(|v| v.name) + .rev() + .collect(); + self.unifier + .unify_call(&call, func.custom.unwrap(), &sign, &required) + .map_err(|old| format!("{} at {}", old, location))?; return Ok(Located { location, custom: Some(sign.ret), diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index a961aa7..fbbf776 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -374,6 +374,54 @@ impl Unifier { } } + pub fn unify_call(&mut self, call: &Call, b: Type, signature: &FunSignature, required: &[StrRef]) -> Result<(), String> { + let Call { posargs, kwargs, ret, fun } = call; + let instantiated = self.instantiate_fun(b, &*signature); + let r = self.get_ty(instantiated); + let r = r.as_ref(); + let signature; + if let TypeEnum::TFunc(s) = &*r { + signature = s; + } else { + unreachable!(); + } + // we check to make sure that all required arguments (those without default + // arguments) are provided, and do not provide the same argument twice. + let mut required = required.to_vec(); + let mut all_names: Vec<_> = signature + .borrow() + .args + .iter() + .map(|v| (v.name, v.ty)) + .rev() + .collect(); + for (i, t) in posargs.iter().enumerate() { + if signature.borrow().args.len() <= i { + return Err("Too many arguments.".to_string()); + } + if !required.is_empty() { + required.pop(); + } + self.unify(all_names.pop().unwrap().1, *t)?; + } + for (k, t) in kwargs.iter() { + if let Some(i) = required.iter().position(|v| v == k) { + required.remove(i); + } + let i = all_names + .iter() + .position(|v| &v.0 == k) + .ok_or_else(|| format!("Unknown keyword argument {}", k))?; + self.unify(all_names.remove(i).1, *t)?; + } + if !required.is_empty() { + return Err("Expected more arguments".to_string()); + } + self.unify(*ret, signature.borrow().ret)?; + *fun.borrow_mut() = Some(instantiated); + Ok(()) + } + pub fn unify(&mut self, a: Type, b: Type) -> Result<(), String> { if self.unification_table.unioned(a, b) { Ok(()) @@ -574,51 +622,10 @@ impl Unifier { .rev() .collect(); // we unify every calls to the function signature. + let signature = signature.borrow(); for c in calls.borrow().iter() { - let Call { posargs, kwargs, ret, fun } = &*self.calls[c.0].clone(); - let instantiated = self.instantiate_fun(b, &*signature.borrow()); - let r = self.get_ty(instantiated); - let r = r.as_ref(); - let signature; - if let TypeEnum::TFunc(s) = &*r { - signature = s; - } else { - unreachable!(); - } - // we check to make sure that all required arguments (those without default - // arguments) are provided, and do not provide the same argument twice. - let mut required = required.clone(); - let mut all_names: Vec<_> = signature - .borrow() - .args - .iter() - .map(|v| (v.name, v.ty)) - .rev() - .collect(); - for (i, t) in posargs.iter().enumerate() { - if signature.borrow().args.len() <= i { - return Err("Too many arguments.".to_string()); - } - if !required.is_empty() { - required.pop(); - } - self.unify(all_names.pop().unwrap().1, *t)?; - } - for (k, t) in kwargs.iter() { - if let Some(i) = required.iter().position(|v| v == k) { - required.remove(i); - } - let i = all_names - .iter() - .position(|v| &v.0 == k) - .ok_or_else(|| format!("Unknown keyword argument {}", k))?; - self.unify(all_names.remove(i).1, *t)?; - } - if !required.is_empty() { - return Err("Expected more arguments".to_string()); - } - self.unify(*ret, signature.borrow().ret)?; - *fun.borrow_mut() = Some(instantiated); + let call = self.calls[c.0].clone(); + self.unify_call(&call, b, &signature, &required)?; } self.set_a_to_b(a, b); }