lambda fold
This commit is contained in:
parent
016166de46
commit
22455e43ac
|
@ -6,9 +6,13 @@ use std::rc::Rc;
|
|||
|
||||
use super::magic_methods::*;
|
||||
use super::symbol_resolver::{SymbolResolver, SymbolType};
|
||||
use super::typedef::{Call, Type, TypeEnum, Unifier};
|
||||
use super::typedef::{Call, FunSignature, FuncArg, Type, TypeEnum, Unifier};
|
||||
use itertools::izip;
|
||||
use rustpython_parser::ast::{self, fold};
|
||||
use rustpython_parser::ast::{
|
||||
self,
|
||||
fold::{self, Fold},
|
||||
Arguments, Expr, ExprKind, Located, Location,
|
||||
};
|
||||
|
||||
pub struct PrimitiveStore {
|
||||
int32: Type,
|
||||
|
@ -21,7 +25,7 @@ pub struct PrimitiveStore {
|
|||
pub struct Inferencer<'a> {
|
||||
resolver: &'a mut Box<dyn SymbolResolver>,
|
||||
unifier: &'a mut Unifier,
|
||||
variable_mapping: &'a mut HashMap<String, Type>,
|
||||
variable_mapping: HashMap<String, Type>,
|
||||
calls: &'a mut Vec<Rc<Call>>,
|
||||
primitives: &'a PrimitiveStore,
|
||||
}
|
||||
|
@ -35,10 +39,16 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
|||
}
|
||||
|
||||
fn fold_expr(&mut self, node: ast::Expr<()>) -> Result<ast::Expr<Self::TargetU>, Self::Error> {
|
||||
let expr = match &node.node {
|
||||
ast::ExprKind::Call { .. } => unimplemented!(),
|
||||
ast::ExprKind::Lambda { .. } => unimplemented!(),
|
||||
ast::ExprKind::ListComp { .. } => unimplemented!(),
|
||||
let expr = match node.node {
|
||||
ast::ExprKind::Call {
|
||||
func,
|
||||
args,
|
||||
keywords,
|
||||
} => unimplemented!(),
|
||||
ast::ExprKind::Lambda { args, body } => {
|
||||
self.fold_lambda(node.location, *args, *body)?
|
||||
}
|
||||
ast::ExprKind::ListComp { elt, generators } => unimplemented!(),
|
||||
_ => fold::fold_expr(self, node)?,
|
||||
};
|
||||
let custom = match &expr.node {
|
||||
|
@ -59,11 +69,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
|||
ops,
|
||||
comparators,
|
||||
} => Some(self.infer_compare(left, ops, comparators)?),
|
||||
ast::ExprKind::Call {
|
||||
func,
|
||||
args,
|
||||
keywords,
|
||||
} => unimplemented!(),
|
||||
ast::ExprKind::Call { .. } => expr.custom,
|
||||
ast::ExprKind::Subscript {
|
||||
value,
|
||||
slice,
|
||||
|
@ -117,6 +123,69 @@ impl<'a> Inferencer<'a> {
|
|||
Ok(ret)
|
||||
}
|
||||
|
||||
fn fold_lambda(
|
||||
&mut self,
|
||||
location: Location,
|
||||
args: Arguments,
|
||||
body: ast::Expr<()>,
|
||||
) -> Result<ast::Expr<Option<Type>>, String> {
|
||||
if !args.posonlyargs.is_empty()
|
||||
|| args.vararg.is_some()
|
||||
|| !args.kwonlyargs.is_empty()
|
||||
|| args.kwarg.is_some()
|
||||
|| !args.defaults.is_empty()
|
||||
{
|
||||
// actually I'm not sure whether programs violating this is a valid python program.
|
||||
return Err(
|
||||
"We only support positional or keyword arguments without defaults for lambdas."
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
|
||||
let fn_args: Vec<_> = args
|
||||
.args
|
||||
.iter()
|
||||
.map(|v| (v.node.arg.clone(), self.unifier.get_fresh_var().0))
|
||||
.collect();
|
||||
let mut variable_mapping = self.variable_mapping.clone();
|
||||
variable_mapping.extend(fn_args.iter().cloned());
|
||||
let ret = self.unifier.get_fresh_var().0;
|
||||
let mut new_context = Inferencer {
|
||||
resolver: self.resolver,
|
||||
unifier: self.unifier,
|
||||
variable_mapping,
|
||||
calls: self.calls,
|
||||
primitives: self.primitives,
|
||||
};
|
||||
let fun = FunSignature {
|
||||
args: fn_args
|
||||
.iter()
|
||||
.map(|(k, ty)| FuncArg {
|
||||
name: k.clone(),
|
||||
ty: *ty,
|
||||
is_optional: false,
|
||||
})
|
||||
.collect(),
|
||||
ret,
|
||||
vars: Default::default(),
|
||||
};
|
||||
let body = new_context.fold_expr(body)?;
|
||||
new_context.unifier.unify(fun.ret, body.custom.unwrap())?;
|
||||
let mut args = new_context.fold_arguments(args)?;
|
||||
for (arg, (name, ty)) in args.args.iter_mut().zip(fn_args.iter()) {
|
||||
assert_eq!(&arg.node.arg, name);
|
||||
arg.custom = Some(*ty);
|
||||
}
|
||||
Ok(Located {
|
||||
location,
|
||||
node: ExprKind::Lambda {
|
||||
args: args.into(),
|
||||
body: body.into(),
|
||||
},
|
||||
custom: Some(self.unifier.add_ty(TypeEnum::TFunc(fun))),
|
||||
})
|
||||
}
|
||||
|
||||
fn infer_identifier(&mut self, id: &str) -> InferenceResult {
|
||||
if let Some(ty) = self.variable_mapping.get(id) {
|
||||
Ok(*ty)
|
||||
|
|
Loading…
Reference in New Issue