lambda fold
This commit is contained in:
parent
016166de46
commit
22455e43ac
|
@ -6,9 +6,13 @@ use std::rc::Rc;
|
||||||
|
|
||||||
use super::magic_methods::*;
|
use super::magic_methods::*;
|
||||||
use super::symbol_resolver::{SymbolResolver, SymbolType};
|
use super::symbol_resolver::{SymbolResolver, SymbolType};
|
||||||
use super::typedef::{Call, Type, TypeEnum, Unifier};
|
use super::typedef::{Call, FunSignature, FuncArg, Type, TypeEnum, Unifier};
|
||||||
use itertools::izip;
|
use itertools::izip;
|
||||||
use rustpython_parser::ast::{self, fold};
|
use rustpython_parser::ast::{
|
||||||
|
self,
|
||||||
|
fold::{self, Fold},
|
||||||
|
Arguments, Expr, ExprKind, Located, Location,
|
||||||
|
};
|
||||||
|
|
||||||
pub struct PrimitiveStore {
|
pub struct PrimitiveStore {
|
||||||
int32: Type,
|
int32: Type,
|
||||||
|
@ -21,7 +25,7 @@ pub struct PrimitiveStore {
|
||||||
pub struct Inferencer<'a> {
|
pub struct Inferencer<'a> {
|
||||||
resolver: &'a mut Box<dyn SymbolResolver>,
|
resolver: &'a mut Box<dyn SymbolResolver>,
|
||||||
unifier: &'a mut Unifier,
|
unifier: &'a mut Unifier,
|
||||||
variable_mapping: &'a mut HashMap<String, Type>,
|
variable_mapping: HashMap<String, Type>,
|
||||||
calls: &'a mut Vec<Rc<Call>>,
|
calls: &'a mut Vec<Rc<Call>>,
|
||||||
primitives: &'a PrimitiveStore,
|
primitives: &'a PrimitiveStore,
|
||||||
}
|
}
|
||||||
|
@ -35,10 +39,16 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn fold_expr(&mut self, node: ast::Expr<()>) -> Result<ast::Expr<Self::TargetU>, Self::Error> {
|
fn fold_expr(&mut self, node: ast::Expr<()>) -> Result<ast::Expr<Self::TargetU>, Self::Error> {
|
||||||
let expr = match &node.node {
|
let expr = match node.node {
|
||||||
ast::ExprKind::Call { .. } => unimplemented!(),
|
ast::ExprKind::Call {
|
||||||
ast::ExprKind::Lambda { .. } => unimplemented!(),
|
func,
|
||||||
ast::ExprKind::ListComp { .. } => unimplemented!(),
|
args,
|
||||||
|
keywords,
|
||||||
|
} => unimplemented!(),
|
||||||
|
ast::ExprKind::Lambda { args, body } => {
|
||||||
|
self.fold_lambda(node.location, *args, *body)?
|
||||||
|
}
|
||||||
|
ast::ExprKind::ListComp { elt, generators } => unimplemented!(),
|
||||||
_ => fold::fold_expr(self, node)?,
|
_ => fold::fold_expr(self, node)?,
|
||||||
};
|
};
|
||||||
let custom = match &expr.node {
|
let custom = match &expr.node {
|
||||||
|
@ -59,11 +69,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
||||||
ops,
|
ops,
|
||||||
comparators,
|
comparators,
|
||||||
} => Some(self.infer_compare(left, ops, comparators)?),
|
} => Some(self.infer_compare(left, ops, comparators)?),
|
||||||
ast::ExprKind::Call {
|
ast::ExprKind::Call { .. } => expr.custom,
|
||||||
func,
|
|
||||||
args,
|
|
||||||
keywords,
|
|
||||||
} => unimplemented!(),
|
|
||||||
ast::ExprKind::Subscript {
|
ast::ExprKind::Subscript {
|
||||||
value,
|
value,
|
||||||
slice,
|
slice,
|
||||||
|
@ -117,6 +123,69 @@ impl<'a> Inferencer<'a> {
|
||||||
Ok(ret)
|
Ok(ret)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn fold_lambda(
|
||||||
|
&mut self,
|
||||||
|
location: Location,
|
||||||
|
args: Arguments,
|
||||||
|
body: ast::Expr<()>,
|
||||||
|
) -> Result<ast::Expr<Option<Type>>, String> {
|
||||||
|
if !args.posonlyargs.is_empty()
|
||||||
|
|| args.vararg.is_some()
|
||||||
|
|| !args.kwonlyargs.is_empty()
|
||||||
|
|| args.kwarg.is_some()
|
||||||
|
|| !args.defaults.is_empty()
|
||||||
|
{
|
||||||
|
// actually I'm not sure whether programs violating this is a valid python program.
|
||||||
|
return Err(
|
||||||
|
"We only support positional or keyword arguments without defaults for lambdas."
|
||||||
|
.to_string(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let fn_args: Vec<_> = args
|
||||||
|
.args
|
||||||
|
.iter()
|
||||||
|
.map(|v| (v.node.arg.clone(), self.unifier.get_fresh_var().0))
|
||||||
|
.collect();
|
||||||
|
let mut variable_mapping = self.variable_mapping.clone();
|
||||||
|
variable_mapping.extend(fn_args.iter().cloned());
|
||||||
|
let ret = self.unifier.get_fresh_var().0;
|
||||||
|
let mut new_context = Inferencer {
|
||||||
|
resolver: self.resolver,
|
||||||
|
unifier: self.unifier,
|
||||||
|
variable_mapping,
|
||||||
|
calls: self.calls,
|
||||||
|
primitives: self.primitives,
|
||||||
|
};
|
||||||
|
let fun = FunSignature {
|
||||||
|
args: fn_args
|
||||||
|
.iter()
|
||||||
|
.map(|(k, ty)| FuncArg {
|
||||||
|
name: k.clone(),
|
||||||
|
ty: *ty,
|
||||||
|
is_optional: false,
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
ret,
|
||||||
|
vars: Default::default(),
|
||||||
|
};
|
||||||
|
let body = new_context.fold_expr(body)?;
|
||||||
|
new_context.unifier.unify(fun.ret, body.custom.unwrap())?;
|
||||||
|
let mut args = new_context.fold_arguments(args)?;
|
||||||
|
for (arg, (name, ty)) in args.args.iter_mut().zip(fn_args.iter()) {
|
||||||
|
assert_eq!(&arg.node.arg, name);
|
||||||
|
arg.custom = Some(*ty);
|
||||||
|
}
|
||||||
|
Ok(Located {
|
||||||
|
location,
|
||||||
|
node: ExprKind::Lambda {
|
||||||
|
args: args.into(),
|
||||||
|
body: body.into(),
|
||||||
|
},
|
||||||
|
custom: Some(self.unifier.add_ty(TypeEnum::TFunc(fun))),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
fn infer_identifier(&mut self, id: &str) -> InferenceResult {
|
fn infer_identifier(&mut self, id: &str) -> InferenceResult {
|
||||||
if let Some(ty) = self.variable_mapping.get(id) {
|
if let Some(ty) = self.variable_mapping.get(id) {
|
||||||
Ok(*ty)
|
Ok(*ty)
|
||||||
|
|
Loading…
Reference in New Issue