From 22455e43ac7a764adb9f03e97f0ccb9b7b48c3e3 Mon Sep 17 00:00:00 2001
From: pca006132 <john.lck40@gmail.com>
Date: Tue, 20 Jul 2021 11:34:32 +0800
Subject: [PATCH] lambda fold

---
 nac3core/src/typecheck/type_inferencer.rs | 93 ++++++++++++++++++++---
 1 file changed, 81 insertions(+), 12 deletions(-)

diff --git a/nac3core/src/typecheck/type_inferencer.rs b/nac3core/src/typecheck/type_inferencer.rs
index 01289c3a..77cccf22 100644
--- a/nac3core/src/typecheck/type_inferencer.rs
+++ b/nac3core/src/typecheck/type_inferencer.rs
@@ -6,9 +6,13 @@ use std::rc::Rc;
 
 use super::magic_methods::*;
 use super::symbol_resolver::{SymbolResolver, SymbolType};
-use super::typedef::{Call, Type, TypeEnum, Unifier};
+use super::typedef::{Call, FunSignature, FuncArg, Type, TypeEnum, Unifier};
 use itertools::izip;
-use rustpython_parser::ast::{self, fold};
+use rustpython_parser::ast::{
+    self,
+    fold::{self, Fold},
+    Arguments, Expr, ExprKind, Located, Location,
+};
 
 pub struct PrimitiveStore {
     int32: Type,
@@ -21,7 +25,7 @@ pub struct PrimitiveStore {
 pub struct Inferencer<'a> {
     resolver: &'a mut Box<dyn SymbolResolver>,
     unifier: &'a mut Unifier,
-    variable_mapping: &'a mut HashMap<String, Type>,
+    variable_mapping: HashMap<String, Type>,
     calls: &'a mut Vec<Rc<Call>>,
     primitives: &'a PrimitiveStore,
 }
@@ -35,10 +39,16 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
     }
 
     fn fold_expr(&mut self, node: ast::Expr<()>) -> Result<ast::Expr<Self::TargetU>, Self::Error> {
-        let expr = match &node.node {
-            ast::ExprKind::Call { .. } => unimplemented!(),
-            ast::ExprKind::Lambda { .. } => unimplemented!(),
-            ast::ExprKind::ListComp { .. } => unimplemented!(),
+        let expr = match node.node {
+            ast::ExprKind::Call {
+                func,
+                args,
+                keywords,
+            } => unimplemented!(),
+            ast::ExprKind::Lambda { args, body } => {
+                self.fold_lambda(node.location, *args, *body)?
+            }
+            ast::ExprKind::ListComp { elt, generators } => unimplemented!(),
             _ => fold::fold_expr(self, node)?,
         };
         let custom = match &expr.node {
@@ -59,11 +69,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
                 ops,
                 comparators,
             } => Some(self.infer_compare(left, ops, comparators)?),
-            ast::ExprKind::Call {
-                func,
-                args,
-                keywords,
-            } => unimplemented!(),
+            ast::ExprKind::Call { .. } => expr.custom,
             ast::ExprKind::Subscript {
                 value,
                 slice,
@@ -117,6 +123,69 @@ impl<'a> Inferencer<'a> {
         Ok(ret)
     }
 
+    fn fold_lambda(
+        &mut self,
+        location: Location,
+        args: Arguments,
+        body: ast::Expr<()>,
+    ) -> Result<ast::Expr<Option<Type>>, String> {
+        if !args.posonlyargs.is_empty()
+            || args.vararg.is_some()
+            || !args.kwonlyargs.is_empty()
+            || args.kwarg.is_some()
+            || !args.defaults.is_empty()
+        {
+            // actually I'm not sure whether programs violating this is a valid python program.
+            return Err(
+                "We only support positional or keyword arguments without defaults for lambdas."
+                    .to_string(),
+            );
+        }
+
+        let fn_args: Vec<_> = args
+            .args
+            .iter()
+            .map(|v| (v.node.arg.clone(), self.unifier.get_fresh_var().0))
+            .collect();
+        let mut variable_mapping = self.variable_mapping.clone();
+        variable_mapping.extend(fn_args.iter().cloned());
+        let ret = self.unifier.get_fresh_var().0;
+        let mut new_context = Inferencer {
+            resolver: self.resolver,
+            unifier: self.unifier,
+            variable_mapping,
+            calls: self.calls,
+            primitives: self.primitives,
+        };
+        let fun = FunSignature {
+            args: fn_args
+                .iter()
+                .map(|(k, ty)| FuncArg {
+                    name: k.clone(),
+                    ty: *ty,
+                    is_optional: false,
+                })
+                .collect(),
+            ret,
+            vars: Default::default(),
+        };
+        let body = new_context.fold_expr(body)?;
+        new_context.unifier.unify(fun.ret, body.custom.unwrap())?;
+        let mut args = new_context.fold_arguments(args)?;
+        for (arg, (name, ty)) in args.args.iter_mut().zip(fn_args.iter()) {
+            assert_eq!(&arg.node.arg, name);
+            arg.custom = Some(*ty);
+        }
+        Ok(Located {
+            location,
+            node: ExprKind::Lambda {
+                args: args.into(),
+                body: body.into(),
+            },
+            custom: Some(self.unifier.add_ty(TypeEnum::TFunc(fun))),
+        })
+    }
+
     fn infer_identifier(&mut self, id: &str) -> InferenceResult {
         if let Some(ty) = self.variable_mapping.get(id) {
             Ok(*ty)