hm-inference #6

Merged
sb10q merged 136 commits from hm-inference into master 2021-08-19 11:46:50 +08:00
3 changed files with 179 additions and 26 deletions
Showing only changes of commit 88c45172b2 - Show all commits

View File

@ -1,20 +1,41 @@
use super::type_inferencer::Inferencer; use super::type_inferencer::Inferencer;
use super::typedef::Type; use super::typedef::Type;
use rustpython_parser::ast::{self, Expr, ExprKind, StmtKind}; use rustpython_parser::ast::{self, Expr, ExprKind, Stmt, StmtKind};
use std::iter::once; use std::iter::once;
impl<'a> Inferencer<'a> { impl<'a> Inferencer<'a> {
fn check_pattern(
&mut self,
pattern: &Expr<Option<Type>>,
defined_identifiers: &mut Vec<String>,
) {
match &pattern.node {
ExprKind::Name { id, .. } => {
if !defined_identifiers.contains(id) {
defined_identifiers.push(id.clone());
}
}
ExprKind::Tuple { elts, .. } => {
for elt in elts.iter() {
self.check_pattern(elt, defined_identifiers);
}
}
_ => unimplemented!(),
}
}
fn check_expr( fn check_expr(
&mut self, &mut self,
expr: &Expr<Option<Type>>, expr: &Expr<Option<Type>>,
defined_identifiers: &[String], defined_identifiers: &[String],
) -> Result<(), String> { ) -> Result<(), String> {
// there are some cases where the custom field is None
if let Some(ty) = &expr.custom { if let Some(ty) = &expr.custom {
let ty = self.unifier.get_ty(*ty); let ty = self.unifier.get_ty(*ty);
let ty = ty.as_ref().borrow(); let ty = ty.as_ref().borrow();
if ty.is_concrete() { if !ty.is_concrete() {
return Err(format!( return Err(format!(
"expected concrete type at {:?} but got {}", "expected concrete type at {} but got {}",
expr.location, expr.location,
ty.get_type_name() ty.get_type_name()
)); ));
@ -23,7 +44,7 @@ impl<'a> Inferencer<'a> {
match &expr.node { match &expr.node {
ExprKind::Name { id, .. } => { ExprKind::Name { id, .. } => {
if !defined_identifiers.contains(id) { if !defined_identifiers.contains(id) {
return Err(format!("unknown identifier {} (use before def?)", id)); return Err(format!("unknown identifier {} (use before def?) at {}", id, expr.location));
} }
} }
ExprKind::List { elts, .. } ExprKind::List { elts, .. }
@ -34,14 +55,14 @@ impl<'a> Inferencer<'a> {
} }
} }
ExprKind::Attribute { value, .. } => { ExprKind::Attribute { value, .. } => {
self.check_expr(value.as_ref(), defined_identifiers)?; self.check_expr(value, defined_identifiers)?;
} }
ExprKind::BinOp { left, right, .. } => { ExprKind::BinOp { left, right, .. } => {
self.check_expr(left.as_ref(), defined_identifiers)?; self.check_expr(left, defined_identifiers)?;
self.check_expr(right.as_ref(), defined_identifiers)?; self.check_expr(right, defined_identifiers)?;
} }
ExprKind::UnaryOp { operand, .. } => { ExprKind::UnaryOp { operand, .. } => {
self.check_expr(operand.as_ref(), defined_identifiers)?; self.check_expr(operand, defined_identifiers)?;
} }
ExprKind::Compare { ExprKind::Compare {
left, comparators, .. left, comparators, ..
@ -51,13 +72,13 @@ impl<'a> Inferencer<'a> {
} }
} }
ExprKind::Subscript { value, slice, .. } => { ExprKind::Subscript { value, slice, .. } => {
self.check_expr(value.as_ref(), defined_identifiers)?; self.check_expr(value, defined_identifiers)?;
self.check_expr(slice.as_ref(), defined_identifiers)?; self.check_expr(slice, defined_identifiers)?;
} }
ExprKind::IfExp { test, body, orelse } => { ExprKind::IfExp { test, body, orelse } => {
self.check_expr(test.as_ref(), defined_identifiers)?; self.check_expr(test, defined_identifiers)?;
self.check_expr(body.as_ref(), defined_identifiers)?; self.check_expr(body, defined_identifiers)?;
self.check_expr(orelse.as_ref(), defined_identifiers)?; self.check_expr(orelse, defined_identifiers)?;
} }
ExprKind::Slice { lower, upper, step } => { ExprKind::Slice { lower, upper, step } => {
for elt in [lower.as_ref(), upper.as_ref(), step.as_ref()] for elt in [lower.as_ref(), upper.as_ref(), step.as_ref()]
@ -67,10 +88,132 @@ impl<'a> Inferencer<'a> {
self.check_expr(elt, defined_identifiers)?; self.check_expr(elt, defined_identifiers)?;
} }
} }
ExprKind::ListComp { .. } => unimplemented!(), ExprKind::Lambda { args, body } => {
ExprKind::Lambda { .. } => unimplemented!(), let mut defined_identifiers = defined_identifiers.to_vec();
_ => {} for arg in args.args.iter() {
if !defined_identifiers.contains(&arg.node.arg) {
defined_identifiers.push(arg.node.arg.clone());
}
}
self.check_expr(body, &defined_identifiers)?;
}
ExprKind::ListComp {
elt, generators, ..
} => {
// in our type inference stage, we already make sure that there is only 1 generator
let ast::Comprehension {
target, iter, ifs, ..
} = &generators[0];
self.check_expr(iter, defined_identifiers)?;
let mut defined_identifiers = defined_identifiers.to_vec();
self.check_pattern(target, &mut defined_identifiers);
for term in once(elt.as_ref()).chain(ifs.iter()) {
self.check_expr(term, &defined_identifiers)?;
}
}
ExprKind::Call {
func,
args,
keywords,
} => {
for expr in once(func.as_ref())
.chain(args.iter())
.chain(keywords.iter().map(|v| v.node.value.as_ref()))
{
self.check_expr(expr, defined_identifiers)?;
}
}
ExprKind::Constant { .. } => {}
_ => {
println!("{:?}", expr.node);
unimplemented!()
}
} }
Ok(()) Ok(())
} }
fn check_stmt(
&mut self,
stmt: &Stmt<Option<Type>>,
defined_identifiers: &mut Vec<String>,
) -> Result<bool, String> {
match &stmt.node {
StmtKind::For {
target,
iter,
body,
orelse,
..
} => {
self.check_expr(iter, defined_identifiers)?;
for stmt in orelse.iter() {
self.check_stmt(stmt, defined_identifiers)?;
}
let mut defined_identifiers = defined_identifiers.clone();
self.check_pattern(target, &mut defined_identifiers);
for stmt in body.iter() {
self.check_stmt(stmt, &mut defined_identifiers)?;
}
Ok(false)
}
StmtKind::If { test, body, orelse } => {
self.check_expr(test, defined_identifiers)?;
let mut body_identifiers = defined_identifiers.clone();
let mut orelse_identifiers = defined_identifiers.clone();
let body_returned = self.check_block(body, &mut body_identifiers)?;
let orelse_returned = self.check_block(orelse, &mut orelse_identifiers)?;
for ident in body_identifiers.iter() {
if !defined_identifiers.contains(ident) && orelse_identifiers.contains(ident) {
defined_identifiers.push(ident.clone())
}
}
Ok(body_returned && orelse_returned)
}
StmtKind::While { test, body, orelse } => {
self.check_expr(test, defined_identifiers)?;
let mut defined_identifiers = defined_identifiers.clone();
self.check_block(body, &mut defined_identifiers)?;
self.check_block(orelse, &mut defined_identifiers)?;
Ok(false)
}
StmtKind::Expr { value } => {
self.check_expr(value, defined_identifiers)?;
Ok(false)
}
StmtKind::Assign { targets, value, .. } => {
self.check_expr(value, defined_identifiers)?;
for target in targets {
self.check_pattern(target, defined_identifiers);
}
Ok(false)
}
StmtKind::AnnAssign { target, value, .. } => {
if let Some(value) = value {
self.check_expr(value, defined_identifiers)?;
self.check_pattern(target, defined_identifiers);
}
Ok(false)
}
// break, return, raise, etc.
_ => Ok(false),
}
}
pub fn check_block(
&mut self,
block: &[Stmt<Option<Type>>],
defined_identifiers: &mut Vec<String>,
) -> Result<bool, String> {
let mut ret = false;
for stmt in block {
if ret {
return Err(format!("dead code at {:?}", stmt.location));
}
if self.check_stmt(stmt, defined_identifiers)? {
ret = true;
}
}
Ok(ret)
}
} }

View File

@ -107,6 +107,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
} }
} }
ast::StmtKind::AnnAssign { .. } | ast::StmtKind::Expr { .. } => {} ast::StmtKind::AnnAssign { .. } | ast::StmtKind::Expr { .. } => {}
ast::StmtKind::Break | ast::StmtKind::Continue => {}
ast::StmtKind::Return { value } => match (value, self.return_type) { ast::StmtKind::Return { value } => match (value, self.return_type) {
(Some(v), Some(v1)) => { (Some(v), Some(v1)) => {
self.unifier.unify(v.custom.unwrap(), v1)?; self.unifier.unify(v.custom.unwrap(), v1)?;
@ -130,12 +131,14 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
func, func,
args, args,
keywords, keywords,
} => self.fold_call(node.location, *func, args, keywords)?, } => {
return self.fold_call(node.location, *func, args, keywords);
}
ast::ExprKind::Lambda { args, body } => { ast::ExprKind::Lambda { args, body } => {
self.fold_lambda(node.location, *args, *body)? return self.fold_lambda(node.location, *args, *body);
} }
ast::ExprKind::ListComp { elt, generators } => { ast::ExprKind::ListComp { elt, generators } => {
self.fold_listcomp(node.location, *elt, generators)? return self.fold_listcomp(node.location, *elt, generators);
} }
_ => fold::fold_expr(self, node)?, _ => fold::fold_expr(self, node)?,
}; };

View File

@ -8,12 +8,12 @@ use rustpython_parser::parser::parse_program;
use test_case::test_case; use test_case::test_case;
struct Resolver { struct Resolver {
type_mapping: HashMap<String, Type>, identifier_mapping: HashMap<String, Type>,
} }
impl SymbolResolver for Resolver { impl SymbolResolver for Resolver {
fn get_symbol_type(&mut self, str: &str) -> Option<Type> { fn get_symbol_type(&mut self, str: &str) -> Option<Type> {
self.type_mapping.get(str).cloned() self.identifier_mapping.get(str).cloned()
} }
fn parse_type_name(&mut self, _: &ast::Expr<()>) -> Option<Type> { fn parse_type_name(&mut self, _: &ast::Expr<()>) -> Option<Type> {
@ -35,12 +35,13 @@ struct TestEnvironment {
pub calls: Vec<Rc<Call>>, pub calls: Vec<Rc<Call>>,
pub primitives: PrimitiveStore, pub primitives: PrimitiveStore,
pub id_to_name: HashMap<usize, String>, pub id_to_name: HashMap<usize, String>,
pub identifier_mapping: HashMap<String, Type>,
} }
impl TestEnvironment { impl TestEnvironment {
fn new() -> TestEnvironment { fn new() -> TestEnvironment {
let mut unifier = Unifier::new(); let mut unifier = Unifier::new();
let mut type_mapping = HashMap::new(); let mut identifier_mapping = HashMap::new();
let int32 = unifier.add_ty(TypeEnum::TObj { let int32 = unifier.add_ty(TypeEnum::TObj {
obj_id: 0, obj_id: 0,
fields: HashMap::new(), fields: HashMap::new(),
@ -66,7 +67,7 @@ impl TestEnvironment {
fields: HashMap::new(), fields: HashMap::new(),
params: HashMap::new(), params: HashMap::new(),
}); });
type_mapping.insert("None".into(), none); identifier_mapping.insert("None".into(), none);
let primitives = PrimitiveStore { let primitives = PrimitiveStore {
int32, int32,
@ -84,7 +85,7 @@ impl TestEnvironment {
params: [(id, v0)].iter().cloned().collect(), params: [(id, v0)].iter().cloned().collect(),
}); });
type_mapping.insert( identifier_mapping.insert(
"Foo".into(), "Foo".into(),
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![], args: vec![],
@ -105,13 +106,14 @@ impl TestEnvironment {
.cloned() .cloned()
.collect(); .collect();
let resolver = Box::new(Resolver { type_mapping }) as Box<dyn SymbolResolver>; let resolver = Box::new(Resolver { identifier_mapping: identifier_mapping.clone() }) as Box<dyn SymbolResolver>;
TestEnvironment { TestEnvironment {
unifier, unifier,
resolver, resolver,
primitives, primitives,
id_to_name, id_to_name,
identifier_mapping,
calls: Vec::new(), calls: Vec::new(),
} }
} }
@ -168,15 +170,20 @@ impl TestEnvironment {
[("a", "list[int32]"), ("b", "list[bool]"), ("f", "fn[[x=int32], bool]")].iter().cloned().collect() [("a", "list[int32]"), ("b", "list[bool]"), ("f", "fn[[x=int32], bool]")].iter().cloned().collect()
; "listcomp test")] ; "listcomp test")]
fn test_basic(source: &str, mapping: HashMap<&str, &str>) { fn test_basic(source: &str, mapping: HashMap<&str, &str>) {
println!("source:\n{}", source);
let mut env = TestEnvironment::new(); let mut env = TestEnvironment::new();
let id_to_name = std::mem::take(&mut env.id_to_name); let id_to_name = std::mem::take(&mut env.id_to_name);
let mut defined_identifiers = env.identifier_mapping.keys().cloned().collect();
let mut inferencer = env.get_inferencer(); let mut inferencer = env.get_inferencer();
let statements = parse_program(source).unwrap(); let statements = parse_program(source).unwrap();
statements let statements = statements
.into_iter() .into_iter()
.map(|v| inferencer.fold_stmt(v)) .map(|v| inferencer.fold_stmt(v))
.collect::<Result<Vec<_>, _>>() .collect::<Result<Vec<_>, _>>()
.unwrap(); .unwrap();
inferencer.check_block(&statements, &mut defined_identifiers).unwrap();
for (k, v) in inferencer.variable_mapping.iter() { for (k, v) in inferencer.variable_mapping.iter() {
let name = inferencer.unifier.stringify( let name = inferencer.unifier.stringify(
*v, *v,