diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 8c1f68be9..f6a61293b 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -238,8 +238,32 @@ impl<'a> Inferencer<'a> { method: String, obj: Type, params: Vec, - ret: Type, + ret: Option, ) -> InferenceResult { + if let TypeEnum::TObj { params: class_params, fields, .. } = &*self.unifier.get_ty(obj) { + if class_params.borrow().is_empty() { + if let Some(ty) = fields.borrow().get(&method) { + if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(*ty) { + let sign = sign.borrow(); + if sign.vars.is_empty() { + let call = self.unifier.add_call(Call { + posargs: params, + kwargs: HashMap::new(), + fun: RefCell::new(Some(*ty)), + ret: sign.ret, + }); + if let Some(ret) = ret { + self.unifier.unify(sign.ret, ret).unwrap(); + } + self.calls.insert(location.into(), call); + return Ok(sign.ret) + } + } + } + } + } + let ret = ret.unwrap_or_else(|| self.unifier.get_fresh_var().0); + let call = self.unifier.add_call(Call { posargs: params, kwargs: HashMap::new(), @@ -460,6 +484,24 @@ impl<'a> Inferencer<'a> { .into_iter() .map(|v| fold::fold_keyword(self, v)) .collect::, _>>()?; + + if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(func.custom.unwrap()) { + let sign = sign.borrow(); + if sign.vars.is_empty() { + let call = self.unifier.add_call(Call { + posargs: args.iter().map(|v| v.custom.unwrap()).collect(), + kwargs: keywords + .iter() + .map(|v| (v.node.arg.as_ref().unwrap().clone(), v.custom.unwrap())) + .collect(), + fun: RefCell::new(func.custom), + ret: sign.ret, + }); + self.calls.insert(location.into(), call); + return Ok(Located { location, custom: Some(sign.ret), node: ExprKind::Call { func, args, keywords } }) + } + } + let ret = self.unifier.get_fresh_var().0; let call = self.unifier.add_call(Call { posargs: args.iter().map(|v| v.custom.unwrap()).collect(), @@ -553,13 +595,12 @@ impl<'a> Inferencer<'a> { right: &ast::Expr>, ) -> InferenceResult { let method = binop_name(op); - let ret = self.unifier.get_fresh_var().0; self.build_method_call( location, method.to_string(), left.custom.unwrap(), vec![right.custom.unwrap()], - ret, + None, ) } @@ -569,13 +610,12 @@ impl<'a> Inferencer<'a> { operand: &ast::Expr>, ) -> InferenceResult { let method = unaryop_name(op); - let ret = self.unifier.get_fresh_var().0; self.build_method_call( operand.location, method.to_string(), operand.custom.unwrap(), vec![], - ret, + None, ) } @@ -594,7 +634,7 @@ impl<'a> Inferencer<'a> { method, a.custom.unwrap(), vec![b.custom.unwrap()], - boolean, + Some(boolean), )?; } Ok(boolean)