diff --git a/nac3core/src/typecheck/type_inferencer.rs b/nac3core/src/typecheck/type_inferencer.rs index 01289c3a..77cccf22 100644 --- a/nac3core/src/typecheck/type_inferencer.rs +++ b/nac3core/src/typecheck/type_inferencer.rs @@ -6,9 +6,13 @@ use std::rc::Rc; use super::magic_methods::*; use super::symbol_resolver::{SymbolResolver, SymbolType}; -use super::typedef::{Call, Type, TypeEnum, Unifier}; +use super::typedef::{Call, FunSignature, FuncArg, Type, TypeEnum, Unifier}; use itertools::izip; -use rustpython_parser::ast::{self, fold}; +use rustpython_parser::ast::{ + self, + fold::{self, Fold}, + Arguments, Expr, ExprKind, Located, Location, +}; pub struct PrimitiveStore { int32: Type, @@ -21,7 +25,7 @@ pub struct PrimitiveStore { pub struct Inferencer<'a> { resolver: &'a mut Box, unifier: &'a mut Unifier, - variable_mapping: &'a mut HashMap, + variable_mapping: HashMap, calls: &'a mut Vec>, primitives: &'a PrimitiveStore, } @@ -35,10 +39,16 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { } fn fold_expr(&mut self, node: ast::Expr<()>) -> Result, Self::Error> { - let expr = match &node.node { - ast::ExprKind::Call { .. } => unimplemented!(), - ast::ExprKind::Lambda { .. } => unimplemented!(), - ast::ExprKind::ListComp { .. } => unimplemented!(), + let expr = match node.node { + ast::ExprKind::Call { + func, + args, + keywords, + } => unimplemented!(), + ast::ExprKind::Lambda { args, body } => { + self.fold_lambda(node.location, *args, *body)? + } + ast::ExprKind::ListComp { elt, generators } => unimplemented!(), _ => fold::fold_expr(self, node)?, }; let custom = match &expr.node { @@ -59,11 +69,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { ops, comparators, } => Some(self.infer_compare(left, ops, comparators)?), - ast::ExprKind::Call { - func, - args, - keywords, - } => unimplemented!(), + ast::ExprKind::Call { .. } => expr.custom, ast::ExprKind::Subscript { value, slice, @@ -117,6 +123,69 @@ impl<'a> Inferencer<'a> { Ok(ret) } + fn fold_lambda( + &mut self, + location: Location, + args: Arguments, + body: ast::Expr<()>, + ) -> Result>, String> { + if !args.posonlyargs.is_empty() + || args.vararg.is_some() + || !args.kwonlyargs.is_empty() + || args.kwarg.is_some() + || !args.defaults.is_empty() + { + // actually I'm not sure whether programs violating this is a valid python program. + return Err( + "We only support positional or keyword arguments without defaults for lambdas." + .to_string(), + ); + } + + let fn_args: Vec<_> = args + .args + .iter() + .map(|v| (v.node.arg.clone(), self.unifier.get_fresh_var().0)) + .collect(); + let mut variable_mapping = self.variable_mapping.clone(); + variable_mapping.extend(fn_args.iter().cloned()); + let ret = self.unifier.get_fresh_var().0; + let mut new_context = Inferencer { + resolver: self.resolver, + unifier: self.unifier, + variable_mapping, + calls: self.calls, + primitives: self.primitives, + }; + let fun = FunSignature { + args: fn_args + .iter() + .map(|(k, ty)| FuncArg { + name: k.clone(), + ty: *ty, + is_optional: false, + }) + .collect(), + ret, + vars: Default::default(), + }; + let body = new_context.fold_expr(body)?; + new_context.unifier.unify(fun.ret, body.custom.unwrap())?; + let mut args = new_context.fold_arguments(args)?; + for (arg, (name, ty)) in args.args.iter_mut().zip(fn_args.iter()) { + assert_eq!(&arg.node.arg, name); + arg.custom = Some(*ty); + } + Ok(Located { + location, + node: ExprKind::Lambda { + args: args.into(), + body: body.into(), + }, + custom: Some(self.unifier.add_ty(TypeEnum::TFunc(fun))), + }) + } + fn infer_identifier(&mut self, id: &str) -> InferenceResult { if let Some(ty) = self.variable_mapping.get(id) { Ok(*ty)