hm-inference #6

Merged
sb10q merged 136 commits from hm-inference into master 2021-08-19 11:46:50 +08:00
5 changed files with 200 additions and 24 deletions
Showing only changes of commit 2f5c3b3cb7 - Show all commits

View File

@ -2,11 +2,6 @@ use super::typedef::Type;
use super::location::Location; use super::location::Location;
use rustpython_parser::ast::Expr; use rustpython_parser::ast::Expr;
pub enum SymbolType {
TypeName(Type),
Identifier(Type),
}
pub enum SymbolValue<'a> { pub enum SymbolValue<'a> {
I32(i32), I32(i32),
I64(i64), I64(i64),

View File

@ -14,20 +14,23 @@ use rustpython_parser::ast::{
Arguments, Comprehension, ExprKind, Located, Location, Arguments, Comprehension, ExprKind, Located, Location,
}; };
#[cfg(test)]
mod test;
pub struct PrimitiveStore { pub struct PrimitiveStore {
int32: Type, pub int32: Type,
int64: Type, pub int64: Type,
float: Type, pub float: Type,
bool: Type, pub bool: Type,
none: Type, pub none: Type,
} }
pub struct Inferencer<'a> { pub struct Inferencer<'a> {
resolver: &'a mut Box<dyn SymbolResolver>, pub resolver: &'a mut Box<dyn SymbolResolver>,
unifier: &'a mut Unifier, pub unifier: &'a mut Unifier,
variable_mapping: HashMap<String, Type>, pub variable_mapping: HashMap<String, Type>,
calls: &'a mut Vec<Rc<Call>>, pub calls: &'a mut Vec<Rc<Call>>,
primitives: &'a PrimitiveStore, pub primitives: &'a PrimitiveStore,
} }
struct NaiveFolder(); struct NaiveFolder();
@ -69,7 +72,8 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
.resolver .resolver
.parse_type_name(annotation.as_ref()) .parse_type_name(annotation.as_ref())
.ok_or_else(|| "cannot parse type name".to_string())?; .ok_or_else(|| "cannot parse type name".to_string())?;
self.unifier.unify(annotation_type, target.custom.unwrap())?; self.unifier
.unify(annotation_type, target.custom.unwrap())?;
let annotation = Box::new(NaiveFolder().fold_expr(*annotation)?); let annotation = Box::new(NaiveFolder().fold_expr(*annotation)?);
Located { Located {
location: node.location, location: node.location,
@ -102,7 +106,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
} }
} }
ast::StmtKind::AnnAssign { .. } => {} ast::StmtKind::AnnAssign { .. } => {}
_ => return Err("Unsupported statement type".to_string()) _ => return Err("Unsupported statement type".to_string()),
}; };
Ok(stmt) Ok(stmt)
} }
@ -358,7 +362,7 @@ impl<'a> Inferencer<'a> {
if id == "int64" && args.len() == 1 { if id == "int64" && args.len() == 1 {
if let ExprKind::Constant { if let ExprKind::Constant {
value: ast::Constant::Int(val), value: ast::Constant::Int(val),
.. kind,
} = &args[0].node } = &args[0].node
{ {
let int64: Result<i64, _> = val.try_into(); let int64: Result<i64, _> = val.try_into();
@ -377,7 +381,14 @@ impl<'a> Inferencer<'a> {
location: func.location, location: func.location,
node: ExprKind::Name { id, ctx }, node: ExprKind::Name { id, ctx },
}), }),
args: vec![self.fold_expr(args.pop().unwrap())?], args: vec![Located {
location: args[0].location,
custom,
node: ExprKind::Constant {
value: ast::Constant::Int(val.clone()),
kind: kind.clone(),
},
}],
keywords: vec![], keywords: vec![],
}, },
}); });

View File

@ -0,0 +1,163 @@
use super::super::location::Location;
use super::super::symbol_resolver::*;
use super::super::typedef::*;
use super::*;
use indoc::indoc;
use rustpython_parser::ast;
use rustpython_parser::parser::parse_program;
use test_case::test_case;
struct Resolver {
type_mapping: HashMap<String, Type>,
}
impl SymbolResolver for Resolver {
fn get_symbol_type(&mut self, str: &str) -> Option<Type> {
self.type_mapping.get(str).cloned()
}
fn parse_type_name(&mut self, _: &ast::Expr<()>) -> Option<Type> {
unimplemented!()
}
fn get_symbol_value(&mut self, _: &str) -> Option<SymbolValue> {
unimplemented!()
}
fn get_symbol_location(&mut self, _: &str) -> Option<Location> {
unimplemented!()
}
}
struct TestEnvironment {
pub unifier: Unifier,
pub resolver: Box<dyn SymbolResolver>,
pub calls: Vec<Rc<Call>>,
pub primitives: PrimitiveStore,
pub id_to_name: HashMap<usize, String>,
}
impl TestEnvironment {
fn new() -> TestEnvironment {
let mut unifier = Unifier::new();
let mut type_mapping = HashMap::new();
let int32 = unifier.add_ty(TypeEnum::TObj {
obj_id: 0,
fields: HashMap::new(),
params: HashMap::new(),
});
let int64 = unifier.add_ty(TypeEnum::TObj {
obj_id: 1,
fields: HashMap::new(),
params: HashMap::new(),
});
let float = unifier.add_ty(TypeEnum::TObj {
obj_id: 2,
fields: HashMap::new(),
params: HashMap::new(),
});
let bool = unifier.add_ty(TypeEnum::TObj {
obj_id: 3,
fields: HashMap::new(),
params: HashMap::new(),
});
let none = unifier.add_ty(TypeEnum::TObj {
obj_id: 4,
fields: HashMap::new(),
params: HashMap::new(),
});
type_mapping.insert("int32".into(), int32);
type_mapping.insert("int64".into(), int64);
type_mapping.insert("float".into(), float);
type_mapping.insert("bool".into(), bool);
type_mapping.insert("none".into(), none);
let primitives = PrimitiveStore {
int32,
int64,
float,
bool,
none,
};
let (v0, id) = unifier.get_fresh_var();
type_mapping.insert(
"foo".into(),
unifier.add_ty(TypeEnum::TObj {
obj_id: 5,
fields: [("a".into(), v0)].iter().cloned().collect(),
params: [(id, v0)].iter().cloned().collect(),
}),
);
let id_to_name = [
(0, "int32".to_string()),
(1, "int64".to_string()),
(2, "float".to_string()),
(3, "bool".to_string()),
(4, "none".to_string()),
(5, "Foo".to_string()),
]
.iter()
.cloned()
.collect();
let resolver = Box::new(Resolver { type_mapping }) as Box<dyn SymbolResolver>;
TestEnvironment {
unifier,
resolver,
primitives,
id_to_name,
calls: Vec::new(),
}
}
fn get_inferencer(&mut self) -> Inferencer {
Inferencer {
resolver: &mut self.resolver,
unifier: &mut self.unifier,
variable_mapping: Default::default(),
calls: &mut self.calls,
primitives: &mut self.primitives,
}
}
}
#[test_case(indoc! {"
a = 1234
b = int64(2147483648)
c = 1.234
d = True
"},
[("a", "int32"), ("b", "int64"), ("c", "float"), ("d", "bool")].iter().cloned().collect()
; "primitives test")]
#[test_case(indoc! {"
a = lambda x, y: x
b = lambda x: a(x, x)
c = 1.234
d = b(c)
"},
[("a", "fn[[x=float, y=float], float]"), ("b", "fn[[x=float], float]"), ("c", "float"), ("d", "float")].iter().cloned().collect()
; "lambda test")]
fn test_basic(source: &str, mapping: HashMap<&str, &str>) {
let mut env = TestEnvironment::new();
let id_to_name = std::mem::take(&mut env.id_to_name);
let mut inferencer = env.get_inferencer();
let statements = parse_program(source).unwrap();
statements
.into_iter()
.map(|v| inferencer.fold_stmt(v))
.collect::<Result<Vec<_>, _>>()
.unwrap();
for (k, v) in mapping.iter() {
let ty = inferencer.variable_mapping.get(*k).unwrap();
let name = inferencer.unifier.stringify(
*ty,
&mut |v| id_to_name.get(&v).unwrap().clone(),
&mut |v| format!("v{}", v),
);
assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name));
}
}

View File

@ -8,7 +8,7 @@ use std::ops::Deref;
use std::rc::Rc; use std::rc::Rc;
#[cfg(test)] #[cfg(test)]
mod test_typedef; mod test;
#[derive(Copy, Clone, PartialEq, Eq, Debug)] #[derive(Copy, Clone, PartialEq, Eq, Debug)]
/// Handle for a type, implementated as a key in the unification table. /// Handle for a type, implementated as a key in the unification table.
@ -217,10 +217,14 @@ impl Unifier {
} }
TypeEnum::TObj { obj_id, params, .. } => { TypeEnum::TObj { obj_id, params, .. } => {
let name = obj_to_name(*obj_id); let name = obj_to_name(*obj_id);
if params.len() > 0 {
let mut params = params let mut params = params
.values() .values()
.map(|v| self.stringify(*v, obj_to_name, var_to_name)); .map(|v| self.stringify(*v, obj_to_name, var_to_name));
format!("{}[{}]", name, params.join(", ")) format!("{}[{}]", name, params.join(", "))
} else {
name
}
} }
TypeEnum::TCall { .. } => "call".to_owned(), TypeEnum::TCall { .. } => "call".to_owned(),
TypeEnum::TFunc(signature) => { TypeEnum::TFunc(signature) => {
@ -432,6 +436,9 @@ impl Unifier {
return Err(format!("Unknown keyword argument {}", k)); return Err(format!("Unknown keyword argument {}", k));
} }
} }
if !required.is_empty() {
return Err("Expected more arguments".to_string());
}
self.unify(*ret, signature.ret)?; self.unify(*ret, signature.ret)?;
*fun.borrow_mut() = Some(instantiated); *fun.borrow_mut() = Some(instantiated);
} }