lambda fold

This commit is contained in:
pca006132 2021-07-20 11:34:32 +08:00
parent 016166de46
commit 22455e43ac

View File

@ -6,9 +6,13 @@ use std::rc::Rc;
use super::magic_methods::*;
use super::symbol_resolver::{SymbolResolver, SymbolType};
use super::typedef::{Call, Type, TypeEnum, Unifier};
use super::typedef::{Call, FunSignature, FuncArg, Type, TypeEnum, Unifier};
use itertools::izip;
use rustpython_parser::ast::{self, fold};
use rustpython_parser::ast::{
self,
fold::{self, Fold},
Arguments, Expr, ExprKind, Located, Location,
};
pub struct PrimitiveStore {
int32: Type,
@ -21,7 +25,7 @@ pub struct PrimitiveStore {
pub struct Inferencer<'a> {
resolver: &'a mut Box<dyn SymbolResolver>,
unifier: &'a mut Unifier,
variable_mapping: &'a mut HashMap<String, Type>,
variable_mapping: HashMap<String, Type>,
calls: &'a mut Vec<Rc<Call>>,
primitives: &'a PrimitiveStore,
}
@ -35,10 +39,16 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
}
fn fold_expr(&mut self, node: ast::Expr<()>) -> Result<ast::Expr<Self::TargetU>, Self::Error> {
let expr = match &node.node {
ast::ExprKind::Call { .. } => unimplemented!(),
ast::ExprKind::Lambda { .. } => unimplemented!(),
ast::ExprKind::ListComp { .. } => unimplemented!(),
let expr = match node.node {
ast::ExprKind::Call {
func,
args,
keywords,
} => unimplemented!(),
ast::ExprKind::Lambda { args, body } => {
self.fold_lambda(node.location, *args, *body)?
}
ast::ExprKind::ListComp { elt, generators } => unimplemented!(),
_ => fold::fold_expr(self, node)?,
};
let custom = match &expr.node {
@ -59,11 +69,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
ops,
comparators,
} => Some(self.infer_compare(left, ops, comparators)?),
ast::ExprKind::Call {
func,
args,
keywords,
} => unimplemented!(),
ast::ExprKind::Call { .. } => expr.custom,
ast::ExprKind::Subscript {
value,
slice,
@ -117,6 +123,69 @@ impl<'a> Inferencer<'a> {
Ok(ret)
}
fn fold_lambda(
&mut self,
location: Location,
args: Arguments,
body: ast::Expr<()>,
) -> Result<ast::Expr<Option<Type>>, String> {
if !args.posonlyargs.is_empty()
|| args.vararg.is_some()
|| !args.kwonlyargs.is_empty()
|| args.kwarg.is_some()
|| !args.defaults.is_empty()
{
// actually I'm not sure whether programs violating this is a valid python program.
return Err(
"We only support positional or keyword arguments without defaults for lambdas."
.to_string(),
);
}
let fn_args: Vec<_> = args
.args
.iter()
.map(|v| (v.node.arg.clone(), self.unifier.get_fresh_var().0))
.collect();
let mut variable_mapping = self.variable_mapping.clone();
variable_mapping.extend(fn_args.iter().cloned());
let ret = self.unifier.get_fresh_var().0;
let mut new_context = Inferencer {
resolver: self.resolver,
unifier: self.unifier,
variable_mapping,
calls: self.calls,
primitives: self.primitives,
};
let fun = FunSignature {
args: fn_args
.iter()
.map(|(k, ty)| FuncArg {
name: k.clone(),
ty: *ty,
is_optional: false,
})
.collect(),
ret,
vars: Default::default(),
};
let body = new_context.fold_expr(body)?;
new_context.unifier.unify(fun.ret, body.custom.unwrap())?;
let mut args = new_context.fold_arguments(args)?;
for (arg, (name, ty)) in args.args.iter_mut().zip(fn_args.iter()) {
assert_eq!(&arg.node.arg, name);
arg.custom = Some(*ty);
}
Ok(Located {
location,
node: ExprKind::Lambda {
args: args.into(),
body: body.into(),
},
custom: Some(self.unifier.add_ty(TypeEnum::TFunc(fun))),
})
}
fn infer_identifier(&mut self, id: &str) -> InferenceResult {
if let Some(ty) = self.variable_mapping.get(id) {
Ok(*ty)