hm-inference #6
|
@ -1,20 +1,41 @@
|
||||||
use super::type_inferencer::Inferencer;
|
use super::type_inferencer::Inferencer;
|
||||||
use super::typedef::Type;
|
use super::typedef::Type;
|
||||||
use rustpython_parser::ast::{self, Expr, ExprKind, StmtKind};
|
use rustpython_parser::ast::{self, Expr, ExprKind, Stmt, StmtKind};
|
||||||
use std::iter::once;
|
use std::iter::once;
|
||||||
|
|
||||||
impl<'a> Inferencer<'a> {
|
impl<'a> Inferencer<'a> {
|
||||||
|
fn check_pattern(
|
||||||
|
&mut self,
|
||||||
|
pattern: &Expr<Option<Type>>,
|
||||||
|
defined_identifiers: &mut Vec<String>,
|
||||||
|
) {
|
||||||
|
match &pattern.node {
|
||||||
|
ExprKind::Name { id, .. } => {
|
||||||
|
if !defined_identifiers.contains(id) {
|
||||||
|
defined_identifiers.push(id.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ExprKind::Tuple { elts, .. } => {
|
||||||
|
for elt in elts.iter() {
|
||||||
|
self.check_pattern(elt, defined_identifiers);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => unimplemented!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn check_expr(
|
fn check_expr(
|
||||||
&mut self,
|
&mut self,
|
||||||
expr: &Expr<Option<Type>>,
|
expr: &Expr<Option<Type>>,
|
||||||
defined_identifiers: &[String],
|
defined_identifiers: &[String],
|
||||||
) -> Result<(), String> {
|
) -> Result<(), String> {
|
||||||
|
// there are some cases where the custom field is None
|
||||||
if let Some(ty) = &expr.custom {
|
if let Some(ty) = &expr.custom {
|
||||||
let ty = self.unifier.get_ty(*ty);
|
let ty = self.unifier.get_ty(*ty);
|
||||||
let ty = ty.as_ref().borrow();
|
let ty = ty.as_ref().borrow();
|
||||||
if ty.is_concrete() {
|
if !ty.is_concrete() {
|
||||||
return Err(format!(
|
return Err(format!(
|
||||||
"expected concrete type at {:?} but got {}",
|
"expected concrete type at {} but got {}",
|
||||||
expr.location,
|
expr.location,
|
||||||
ty.get_type_name()
|
ty.get_type_name()
|
||||||
));
|
));
|
||||||
|
@ -23,7 +44,7 @@ impl<'a> Inferencer<'a> {
|
||||||
match &expr.node {
|
match &expr.node {
|
||||||
ExprKind::Name { id, .. } => {
|
ExprKind::Name { id, .. } => {
|
||||||
if !defined_identifiers.contains(id) {
|
if !defined_identifiers.contains(id) {
|
||||||
return Err(format!("unknown identifier {} (use before def?)", id));
|
return Err(format!("unknown identifier {} (use before def?) at {}", id, expr.location));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ExprKind::List { elts, .. }
|
ExprKind::List { elts, .. }
|
||||||
|
@ -34,14 +55,14 @@ impl<'a> Inferencer<'a> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ExprKind::Attribute { value, .. } => {
|
ExprKind::Attribute { value, .. } => {
|
||||||
self.check_expr(value.as_ref(), defined_identifiers)?;
|
self.check_expr(value, defined_identifiers)?;
|
||||||
}
|
}
|
||||||
ExprKind::BinOp { left, right, .. } => {
|
ExprKind::BinOp { left, right, .. } => {
|
||||||
self.check_expr(left.as_ref(), defined_identifiers)?;
|
self.check_expr(left, defined_identifiers)?;
|
||||||
self.check_expr(right.as_ref(), defined_identifiers)?;
|
self.check_expr(right, defined_identifiers)?;
|
||||||
}
|
}
|
||||||
ExprKind::UnaryOp { operand, .. } => {
|
ExprKind::UnaryOp { operand, .. } => {
|
||||||
self.check_expr(operand.as_ref(), defined_identifiers)?;
|
self.check_expr(operand, defined_identifiers)?;
|
||||||
}
|
}
|
||||||
ExprKind::Compare {
|
ExprKind::Compare {
|
||||||
left, comparators, ..
|
left, comparators, ..
|
||||||
|
@ -51,13 +72,13 @@ impl<'a> Inferencer<'a> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ExprKind::Subscript { value, slice, .. } => {
|
ExprKind::Subscript { value, slice, .. } => {
|
||||||
self.check_expr(value.as_ref(), defined_identifiers)?;
|
self.check_expr(value, defined_identifiers)?;
|
||||||
self.check_expr(slice.as_ref(), defined_identifiers)?;
|
self.check_expr(slice, defined_identifiers)?;
|
||||||
}
|
}
|
||||||
ExprKind::IfExp { test, body, orelse } => {
|
ExprKind::IfExp { test, body, orelse } => {
|
||||||
self.check_expr(test.as_ref(), defined_identifiers)?;
|
self.check_expr(test, defined_identifiers)?;
|
||||||
self.check_expr(body.as_ref(), defined_identifiers)?;
|
self.check_expr(body, defined_identifiers)?;
|
||||||
self.check_expr(orelse.as_ref(), defined_identifiers)?;
|
self.check_expr(orelse, defined_identifiers)?;
|
||||||
}
|
}
|
||||||
ExprKind::Slice { lower, upper, step } => {
|
ExprKind::Slice { lower, upper, step } => {
|
||||||
for elt in [lower.as_ref(), upper.as_ref(), step.as_ref()]
|
for elt in [lower.as_ref(), upper.as_ref(), step.as_ref()]
|
||||||
|
@ -67,10 +88,132 @@ impl<'a> Inferencer<'a> {
|
||||||
self.check_expr(elt, defined_identifiers)?;
|
self.check_expr(elt, defined_identifiers)?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ExprKind::ListComp { .. } => unimplemented!(),
|
ExprKind::Lambda { args, body } => {
|
||||||
ExprKind::Lambda { .. } => unimplemented!(),
|
let mut defined_identifiers = defined_identifiers.to_vec();
|
||||||
_ => {}
|
for arg in args.args.iter() {
|
||||||
|
if !defined_identifiers.contains(&arg.node.arg) {
|
||||||
|
defined_identifiers.push(arg.node.arg.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self.check_expr(body, &defined_identifiers)?;
|
||||||
|
}
|
||||||
|
ExprKind::ListComp {
|
||||||
|
elt, generators, ..
|
||||||
|
} => {
|
||||||
|
// in our type inference stage, we already make sure that there is only 1 generator
|
||||||
|
let ast::Comprehension {
|
||||||
|
target, iter, ifs, ..
|
||||||
|
} = &generators[0];
|
||||||
|
self.check_expr(iter, defined_identifiers)?;
|
||||||
|
let mut defined_identifiers = defined_identifiers.to_vec();
|
||||||
|
self.check_pattern(target, &mut defined_identifiers);
|
||||||
|
for term in once(elt.as_ref()).chain(ifs.iter()) {
|
||||||
|
self.check_expr(term, &defined_identifiers)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ExprKind::Call {
|
||||||
|
func,
|
||||||
|
args,
|
||||||
|
keywords,
|
||||||
|
} => {
|
||||||
|
for expr in once(func.as_ref())
|
||||||
|
.chain(args.iter())
|
||||||
|
.chain(keywords.iter().map(|v| v.node.value.as_ref()))
|
||||||
|
{
|
||||||
|
self.check_expr(expr, defined_identifiers)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ExprKind::Constant { .. } => {}
|
||||||
|
_ => {
|
||||||
|
println!("{:?}", expr.node);
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn check_stmt(
|
||||||
|
&mut self,
|
||||||
|
stmt: &Stmt<Option<Type>>,
|
||||||
|
defined_identifiers: &mut Vec<String>,
|
||||||
|
) -> Result<bool, String> {
|
||||||
|
match &stmt.node {
|
||||||
|
StmtKind::For {
|
||||||
|
target,
|
||||||
|
iter,
|
||||||
|
body,
|
||||||
|
orelse,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
self.check_expr(iter, defined_identifiers)?;
|
||||||
|
for stmt in orelse.iter() {
|
||||||
|
self.check_stmt(stmt, defined_identifiers)?;
|
||||||
|
}
|
||||||
|
let mut defined_identifiers = defined_identifiers.clone();
|
||||||
|
self.check_pattern(target, &mut defined_identifiers);
|
||||||
|
for stmt in body.iter() {
|
||||||
|
self.check_stmt(stmt, &mut defined_identifiers)?;
|
||||||
|
}
|
||||||
|
Ok(false)
|
||||||
|
}
|
||||||
|
StmtKind::If { test, body, orelse } => {
|
||||||
|
self.check_expr(test, defined_identifiers)?;
|
||||||
|
let mut body_identifiers = defined_identifiers.clone();
|
||||||
|
let mut orelse_identifiers = defined_identifiers.clone();
|
||||||
|
let body_returned = self.check_block(body, &mut body_identifiers)?;
|
||||||
|
let orelse_returned = self.check_block(orelse, &mut orelse_identifiers)?;
|
||||||
|
|
||||||
|
for ident in body_identifiers.iter() {
|
||||||
|
if !defined_identifiers.contains(ident) && orelse_identifiers.contains(ident) {
|
||||||
|
defined_identifiers.push(ident.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(body_returned && orelse_returned)
|
||||||
|
}
|
||||||
|
StmtKind::While { test, body, orelse } => {
|
||||||
|
self.check_expr(test, defined_identifiers)?;
|
||||||
|
let mut defined_identifiers = defined_identifiers.clone();
|
||||||
|
self.check_block(body, &mut defined_identifiers)?;
|
||||||
|
self.check_block(orelse, &mut defined_identifiers)?;
|
||||||
|
Ok(false)
|
||||||
|
}
|
||||||
|
StmtKind::Expr { value } => {
|
||||||
|
self.check_expr(value, defined_identifiers)?;
|
||||||
|
Ok(false)
|
||||||
|
}
|
||||||
|
StmtKind::Assign { targets, value, .. } => {
|
||||||
|
self.check_expr(value, defined_identifiers)?;
|
||||||
|
for target in targets {
|
||||||
|
self.check_pattern(target, defined_identifiers);
|
||||||
|
}
|
||||||
|
Ok(false)
|
||||||
|
}
|
||||||
|
StmtKind::AnnAssign { target, value, .. } => {
|
||||||
|
if let Some(value) = value {
|
||||||
|
self.check_expr(value, defined_identifiers)?;
|
||||||
|
self.check_pattern(target, defined_identifiers);
|
||||||
|
}
|
||||||
|
Ok(false)
|
||||||
|
}
|
||||||
|
// break, return, raise, etc.
|
||||||
|
_ => Ok(false),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn check_block(
|
||||||
|
&mut self,
|
||||||
|
block: &[Stmt<Option<Type>>],
|
||||||
|
defined_identifiers: &mut Vec<String>,
|
||||||
|
) -> Result<bool, String> {
|
||||||
|
let mut ret = false;
|
||||||
|
for stmt in block {
|
||||||
|
if ret {
|
||||||
|
return Err(format!("dead code at {:?}", stmt.location));
|
||||||
|
}
|
||||||
|
if self.check_stmt(stmt, defined_identifiers)? {
|
||||||
|
ret = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(ret)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -107,6 +107,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ast::StmtKind::AnnAssign { .. } | ast::StmtKind::Expr { .. } => {}
|
ast::StmtKind::AnnAssign { .. } | ast::StmtKind::Expr { .. } => {}
|
||||||
|
ast::StmtKind::Break | ast::StmtKind::Continue => {}
|
||||||
ast::StmtKind::Return { value } => match (value, self.return_type) {
|
ast::StmtKind::Return { value } => match (value, self.return_type) {
|
||||||
(Some(v), Some(v1)) => {
|
(Some(v), Some(v1)) => {
|
||||||
self.unifier.unify(v.custom.unwrap(), v1)?;
|
self.unifier.unify(v.custom.unwrap(), v1)?;
|
||||||
|
@ -130,12 +131,14 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
||||||
func,
|
func,
|
||||||
args,
|
args,
|
||||||
keywords,
|
keywords,
|
||||||
} => self.fold_call(node.location, *func, args, keywords)?,
|
} => {
|
||||||
|
return self.fold_call(node.location, *func, args, keywords);
|
||||||
|
}
|
||||||
ast::ExprKind::Lambda { args, body } => {
|
ast::ExprKind::Lambda { args, body } => {
|
||||||
self.fold_lambda(node.location, *args, *body)?
|
return self.fold_lambda(node.location, *args, *body);
|
||||||
}
|
}
|
||||||
ast::ExprKind::ListComp { elt, generators } => {
|
ast::ExprKind::ListComp { elt, generators } => {
|
||||||
self.fold_listcomp(node.location, *elt, generators)?
|
return self.fold_listcomp(node.location, *elt, generators);
|
||||||
}
|
}
|
||||||
_ => fold::fold_expr(self, node)?,
|
_ => fold::fold_expr(self, node)?,
|
||||||
};
|
};
|
||||||
|
|
|
@ -8,12 +8,12 @@ use rustpython_parser::parser::parse_program;
|
||||||
use test_case::test_case;
|
use test_case::test_case;
|
||||||
|
|
||||||
struct Resolver {
|
struct Resolver {
|
||||||
type_mapping: HashMap<String, Type>,
|
identifier_mapping: HashMap<String, Type>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SymbolResolver for Resolver {
|
impl SymbolResolver for Resolver {
|
||||||
fn get_symbol_type(&mut self, str: &str) -> Option<Type> {
|
fn get_symbol_type(&mut self, str: &str) -> Option<Type> {
|
||||||
self.type_mapping.get(str).cloned()
|
self.identifier_mapping.get(str).cloned()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_type_name(&mut self, _: &ast::Expr<()>) -> Option<Type> {
|
fn parse_type_name(&mut self, _: &ast::Expr<()>) -> Option<Type> {
|
||||||
|
@ -35,12 +35,13 @@ struct TestEnvironment {
|
||||||
pub calls: Vec<Rc<Call>>,
|
pub calls: Vec<Rc<Call>>,
|
||||||
pub primitives: PrimitiveStore,
|
pub primitives: PrimitiveStore,
|
||||||
pub id_to_name: HashMap<usize, String>,
|
pub id_to_name: HashMap<usize, String>,
|
||||||
|
pub identifier_mapping: HashMap<String, Type>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TestEnvironment {
|
impl TestEnvironment {
|
||||||
fn new() -> TestEnvironment {
|
fn new() -> TestEnvironment {
|
||||||
let mut unifier = Unifier::new();
|
let mut unifier = Unifier::new();
|
||||||
let mut type_mapping = HashMap::new();
|
let mut identifier_mapping = HashMap::new();
|
||||||
let int32 = unifier.add_ty(TypeEnum::TObj {
|
let int32 = unifier.add_ty(TypeEnum::TObj {
|
||||||
obj_id: 0,
|
obj_id: 0,
|
||||||
fields: HashMap::new(),
|
fields: HashMap::new(),
|
||||||
|
@ -66,7 +67,7 @@ impl TestEnvironment {
|
||||||
fields: HashMap::new(),
|
fields: HashMap::new(),
|
||||||
params: HashMap::new(),
|
params: HashMap::new(),
|
||||||
});
|
});
|
||||||
type_mapping.insert("None".into(), none);
|
identifier_mapping.insert("None".into(), none);
|
||||||
|
|
||||||
let primitives = PrimitiveStore {
|
let primitives = PrimitiveStore {
|
||||||
int32,
|
int32,
|
||||||
|
@ -84,7 +85,7 @@ impl TestEnvironment {
|
||||||
params: [(id, v0)].iter().cloned().collect(),
|
params: [(id, v0)].iter().cloned().collect(),
|
||||||
});
|
});
|
||||||
|
|
||||||
type_mapping.insert(
|
identifier_mapping.insert(
|
||||||
"Foo".into(),
|
"Foo".into(),
|
||||||
unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
args: vec![],
|
args: vec![],
|
||||||
|
@ -105,13 +106,14 @@ impl TestEnvironment {
|
||||||
.cloned()
|
.cloned()
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let resolver = Box::new(Resolver { type_mapping }) as Box<dyn SymbolResolver>;
|
let resolver = Box::new(Resolver { identifier_mapping: identifier_mapping.clone() }) as Box<dyn SymbolResolver>;
|
||||||
|
|
||||||
TestEnvironment {
|
TestEnvironment {
|
||||||
unifier,
|
unifier,
|
||||||
resolver,
|
resolver,
|
||||||
primitives,
|
primitives,
|
||||||
id_to_name,
|
id_to_name,
|
||||||
|
identifier_mapping,
|
||||||
calls: Vec::new(),
|
calls: Vec::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -168,15 +170,20 @@ impl TestEnvironment {
|
||||||
[("a", "list[int32]"), ("b", "list[bool]"), ("f", "fn[[x=int32], bool]")].iter().cloned().collect()
|
[("a", "list[int32]"), ("b", "list[bool]"), ("f", "fn[[x=int32], bool]")].iter().cloned().collect()
|
||||||
; "listcomp test")]
|
; "listcomp test")]
|
||||||
fn test_basic(source: &str, mapping: HashMap<&str, &str>) {
|
fn test_basic(source: &str, mapping: HashMap<&str, &str>) {
|
||||||
|
println!("source:\n{}", source);
|
||||||
let mut env = TestEnvironment::new();
|
let mut env = TestEnvironment::new();
|
||||||
let id_to_name = std::mem::take(&mut env.id_to_name);
|
let id_to_name = std::mem::take(&mut env.id_to_name);
|
||||||
|
let mut defined_identifiers = env.identifier_mapping.keys().cloned().collect();
|
||||||
let mut inferencer = env.get_inferencer();
|
let mut inferencer = env.get_inferencer();
|
||||||
let statements = parse_program(source).unwrap();
|
let statements = parse_program(source).unwrap();
|
||||||
statements
|
let statements = statements
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|v| inferencer.fold_stmt(v))
|
.map(|v| inferencer.fold_stmt(v))
|
||||||
.collect::<Result<Vec<_>, _>>()
|
.collect::<Result<Vec<_>, _>>()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
inferencer.check_block(&statements, &mut defined_identifiers).unwrap();
|
||||||
|
|
||||||
for (k, v) in inferencer.variable_mapping.iter() {
|
for (k, v) in inferencer.variable_mapping.iter() {
|
||||||
let name = inferencer.unifier.stringify(
|
let name = inferencer.unifier.stringify(
|
||||||
*v,
|
*v,
|
||||||
|
|
Loading…
Reference in New Issue