lambda fold

This commit is contained in:
pca006132 2021-07-20 11:34:32 +08:00
parent 016166de46
commit 22455e43ac

View File

@ -6,9 +6,13 @@ use std::rc::Rc;
use super::magic_methods::*; use super::magic_methods::*;
use super::symbol_resolver::{SymbolResolver, SymbolType}; use super::symbol_resolver::{SymbolResolver, SymbolType};
use super::typedef::{Call, Type, TypeEnum, Unifier}; use super::typedef::{Call, FunSignature, FuncArg, Type, TypeEnum, Unifier};
use itertools::izip; use itertools::izip;
use rustpython_parser::ast::{self, fold}; use rustpython_parser::ast::{
self,
fold::{self, Fold},
Arguments, Expr, ExprKind, Located, Location,
};
pub struct PrimitiveStore { pub struct PrimitiveStore {
int32: Type, int32: Type,
@ -21,7 +25,7 @@ pub struct PrimitiveStore {
pub struct Inferencer<'a> { pub struct Inferencer<'a> {
resolver: &'a mut Box<dyn SymbolResolver>, resolver: &'a mut Box<dyn SymbolResolver>,
unifier: &'a mut Unifier, unifier: &'a mut Unifier,
variable_mapping: &'a mut HashMap<String, Type>, variable_mapping: HashMap<String, Type>,
calls: &'a mut Vec<Rc<Call>>, calls: &'a mut Vec<Rc<Call>>,
primitives: &'a PrimitiveStore, primitives: &'a PrimitiveStore,
} }
@ -35,10 +39,16 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
} }
fn fold_expr(&mut self, node: ast::Expr<()>) -> Result<ast::Expr<Self::TargetU>, Self::Error> { fn fold_expr(&mut self, node: ast::Expr<()>) -> Result<ast::Expr<Self::TargetU>, Self::Error> {
let expr = match &node.node { let expr = match node.node {
ast::ExprKind::Call { .. } => unimplemented!(), ast::ExprKind::Call {
ast::ExprKind::Lambda { .. } => unimplemented!(), func,
ast::ExprKind::ListComp { .. } => unimplemented!(), args,
keywords,
} => unimplemented!(),
ast::ExprKind::Lambda { args, body } => {
self.fold_lambda(node.location, *args, *body)?
}
ast::ExprKind::ListComp { elt, generators } => unimplemented!(),
_ => fold::fold_expr(self, node)?, _ => fold::fold_expr(self, node)?,
}; };
let custom = match &expr.node { let custom = match &expr.node {
@ -59,11 +69,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
ops, ops,
comparators, comparators,
} => Some(self.infer_compare(left, ops, comparators)?), } => Some(self.infer_compare(left, ops, comparators)?),
ast::ExprKind::Call { ast::ExprKind::Call { .. } => expr.custom,
func,
args,
keywords,
} => unimplemented!(),
ast::ExprKind::Subscript { ast::ExprKind::Subscript {
value, value,
slice, slice,
@ -117,6 +123,69 @@ impl<'a> Inferencer<'a> {
Ok(ret) Ok(ret)
} }
fn fold_lambda(
&mut self,
location: Location,
args: Arguments,
body: ast::Expr<()>,
) -> Result<ast::Expr<Option<Type>>, String> {
if !args.posonlyargs.is_empty()
|| args.vararg.is_some()
|| !args.kwonlyargs.is_empty()
|| args.kwarg.is_some()
|| !args.defaults.is_empty()
{
// actually I'm not sure whether programs violating this is a valid python program.
return Err(
"We only support positional or keyword arguments without defaults for lambdas."
.to_string(),
);
}
let fn_args: Vec<_> = args
.args
.iter()
.map(|v| (v.node.arg.clone(), self.unifier.get_fresh_var().0))
.collect();
let mut variable_mapping = self.variable_mapping.clone();
variable_mapping.extend(fn_args.iter().cloned());
let ret = self.unifier.get_fresh_var().0;
let mut new_context = Inferencer {
resolver: self.resolver,
unifier: self.unifier,
variable_mapping,
calls: self.calls,
primitives: self.primitives,
};
let fun = FunSignature {
args: fn_args
.iter()
.map(|(k, ty)| FuncArg {
name: k.clone(),
ty: *ty,
is_optional: false,
})
.collect(),
ret,
vars: Default::default(),
};
let body = new_context.fold_expr(body)?;
new_context.unifier.unify(fun.ret, body.custom.unwrap())?;
let mut args = new_context.fold_arguments(args)?;
for (arg, (name, ty)) in args.args.iter_mut().zip(fn_args.iter()) {
assert_eq!(&arg.node.arg, name);
arg.custom = Some(*ty);
}
Ok(Located {
location,
node: ExprKind::Lambda {
args: args.into(),
body: body.into(),
},
custom: Some(self.unifier.add_ty(TypeEnum::TFunc(fun))),
})
}
fn infer_identifier(&mut self, id: &str) -> InferenceResult { fn infer_identifier(&mut self, id: &str) -> InferenceResult {
if let Some(ty) = self.variable_mapping.get(id) { if let Some(ty) = self.variable_mapping.get(id) {
Ok(*ty) Ok(*ty)