hm-inference #6
|
@ -2,11 +2,6 @@ use super::typedef::Type;
|
|||
use super::location::Location;
|
||||
use rustpython_parser::ast::Expr;
|
||||
|
||||
pub enum SymbolType {
|
||||
TypeName(Type),
|
||||
Identifier(Type),
|
||||
}
|
||||
|
||||
pub enum SymbolValue<'a> {
|
||||
I32(i32),
|
||||
I64(i64),
|
||||
|
|
|
@ -14,20 +14,23 @@ use rustpython_parser::ast::{
|
|||
Arguments, Comprehension, ExprKind, Located, Location,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
mod test;
|
||||
|
||||
pub struct PrimitiveStore {
|
||||
int32: Type,
|
||||
int64: Type,
|
||||
float: Type,
|
||||
bool: Type,
|
||||
none: Type,
|
||||
pub int32: Type,
|
||||
pub int64: Type,
|
||||
pub float: Type,
|
||||
pub bool: Type,
|
||||
pub none: Type,
|
||||
}
|
||||
|
||||
pub struct Inferencer<'a> {
|
||||
resolver: &'a mut Box<dyn SymbolResolver>,
|
||||
unifier: &'a mut Unifier,
|
||||
variable_mapping: HashMap<String, Type>,
|
||||
calls: &'a mut Vec<Rc<Call>>,
|
||||
primitives: &'a PrimitiveStore,
|
||||
pub resolver: &'a mut Box<dyn SymbolResolver>,
|
||||
pub unifier: &'a mut Unifier,
|
||||
pub variable_mapping: HashMap<String, Type>,
|
||||
pub calls: &'a mut Vec<Rc<Call>>,
|
||||
pub primitives: &'a PrimitiveStore,
|
||||
}
|
||||
|
||||
struct NaiveFolder();
|
||||
|
@ -69,7 +72,8 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
|||
.resolver
|
||||
.parse_type_name(annotation.as_ref())
|
||||
.ok_or_else(|| "cannot parse type name".to_string())?;
|
||||
self.unifier.unify(annotation_type, target.custom.unwrap())?;
|
||||
self.unifier
|
||||
.unify(annotation_type, target.custom.unwrap())?;
|
||||
let annotation = Box::new(NaiveFolder().fold_expr(*annotation)?);
|
||||
Located {
|
||||
location: node.location,
|
||||
|
@ -102,7 +106,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
|||
}
|
||||
}
|
||||
ast::StmtKind::AnnAssign { .. } => {}
|
||||
_ => return Err("Unsupported statement type".to_string())
|
||||
_ => return Err("Unsupported statement type".to_string()),
|
||||
};
|
||||
Ok(stmt)
|
||||
}
|
||||
|
@ -358,7 +362,7 @@ impl<'a> Inferencer<'a> {
|
|||
if id == "int64" && args.len() == 1 {
|
||||
if let ExprKind::Constant {
|
||||
value: ast::Constant::Int(val),
|
||||
..
|
||||
kind,
|
||||
} = &args[0].node
|
||||
{
|
||||
let int64: Result<i64, _> = val.try_into();
|
||||
|
@ -377,7 +381,14 @@ impl<'a> Inferencer<'a> {
|
|||
location: func.location,
|
||||
node: ExprKind::Name { id, ctx },
|
||||
}),
|
||||
args: vec![self.fold_expr(args.pop().unwrap())?],
|
||||
args: vec![Located {
|
||||
location: args[0].location,
|
||||
custom,
|
||||
node: ExprKind::Constant {
|
||||
value: ast::Constant::Int(val.clone()),
|
||||
kind: kind.clone(),
|
||||
},
|
||||
}],
|
||||
keywords: vec![],
|
||||
},
|
||||
});
|
||||
|
|
|
@ -0,0 +1,163 @@
|
|||
use super::super::location::Location;
|
||||
use super::super::symbol_resolver::*;
|
||||
use super::super::typedef::*;
|
||||
use super::*;
|
||||
use indoc::indoc;
|
||||
use rustpython_parser::ast;
|
||||
use rustpython_parser::parser::parse_program;
|
||||
use test_case::test_case;
|
||||
|
||||
struct Resolver {
|
||||
type_mapping: HashMap<String, Type>,
|
||||
}
|
||||
|
||||
impl SymbolResolver for Resolver {
|
||||
fn get_symbol_type(&mut self, str: &str) -> Option<Type> {
|
||||
self.type_mapping.get(str).cloned()
|
||||
}
|
||||
|
||||
fn parse_type_name(&mut self, _: &ast::Expr<()>) -> Option<Type> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn get_symbol_value(&mut self, _: &str) -> Option<SymbolValue> {
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
fn get_symbol_location(&mut self, _: &str) -> Option<Location> {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
struct TestEnvironment {
|
||||
pub unifier: Unifier,
|
||||
pub resolver: Box<dyn SymbolResolver>,
|
||||
pub calls: Vec<Rc<Call>>,
|
||||
pub primitives: PrimitiveStore,
|
||||
pub id_to_name: HashMap<usize, String>,
|
||||
}
|
||||
|
||||
impl TestEnvironment {
|
||||
fn new() -> TestEnvironment {
|
||||
let mut unifier = Unifier::new();
|
||||
let mut type_mapping = HashMap::new();
|
||||
let int32 = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: 0,
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
});
|
||||
let int64 = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: 1,
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
});
|
||||
let float = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: 2,
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
});
|
||||
let bool = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: 3,
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
});
|
||||
let none = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: 4,
|
||||
fields: HashMap::new(),
|
||||
params: HashMap::new(),
|
||||
});
|
||||
type_mapping.insert("int32".into(), int32);
|
||||
type_mapping.insert("int64".into(), int64);
|
||||
type_mapping.insert("float".into(), float);
|
||||
type_mapping.insert("bool".into(), bool);
|
||||
type_mapping.insert("none".into(), none);
|
||||
|
||||
let primitives = PrimitiveStore {
|
||||
int32,
|
||||
int64,
|
||||
float,
|
||||
bool,
|
||||
none,
|
||||
};
|
||||
|
||||
let (v0, id) = unifier.get_fresh_var();
|
||||
type_mapping.insert(
|
||||
"foo".into(),
|
||||
unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: 5,
|
||||
fields: [("a".into(), v0)].iter().cloned().collect(),
|
||||
params: [(id, v0)].iter().cloned().collect(),
|
||||
}),
|
||||
);
|
||||
|
||||
let id_to_name = [
|
||||
(0, "int32".to_string()),
|
||||
(1, "int64".to_string()),
|
||||
(2, "float".to_string()),
|
||||
(3, "bool".to_string()),
|
||||
(4, "none".to_string()),
|
||||
(5, "Foo".to_string()),
|
||||
]
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
let resolver = Box::new(Resolver { type_mapping }) as Box<dyn SymbolResolver>;
|
||||
|
||||
TestEnvironment {
|
||||
unifier,
|
||||
resolver,
|
||||
primitives,
|
||||
id_to_name,
|
||||
calls: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_inferencer(&mut self) -> Inferencer {
|
||||
Inferencer {
|
||||
resolver: &mut self.resolver,
|
||||
unifier: &mut self.unifier,
|
||||
variable_mapping: Default::default(),
|
||||
calls: &mut self.calls,
|
||||
primitives: &mut self.primitives,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test_case(indoc! {"
|
||||
a = 1234
|
||||
b = int64(2147483648)
|
||||
c = 1.234
|
||||
d = True
|
||||
"},
|
||||
[("a", "int32"), ("b", "int64"), ("c", "float"), ("d", "bool")].iter().cloned().collect()
|
||||
; "primitives test")]
|
||||
#[test_case(indoc! {"
|
||||
a = lambda x, y: x
|
||||
b = lambda x: a(x, x)
|
||||
c = 1.234
|
||||
d = b(c)
|
||||
"},
|
||||
[("a", "fn[[x=float, y=float], float]"), ("b", "fn[[x=float], float]"), ("c", "float"), ("d", "float")].iter().cloned().collect()
|
||||
; "lambda test")]
|
||||
fn test_basic(source: &str, mapping: HashMap<&str, &str>) {
|
||||
let mut env = TestEnvironment::new();
|
||||
let id_to_name = std::mem::take(&mut env.id_to_name);
|
||||
let mut inferencer = env.get_inferencer();
|
||||
let statements = parse_program(source).unwrap();
|
||||
statements
|
||||
.into_iter()
|
||||
.map(|v| inferencer.fold_stmt(v))
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
.unwrap();
|
||||
for (k, v) in mapping.iter() {
|
||||
let ty = inferencer.variable_mapping.get(*k).unwrap();
|
||||
let name = inferencer.unifier.stringify(
|
||||
*ty,
|
||||
&mut |v| id_to_name.get(&v).unwrap().clone(),
|
||||
&mut |v| format!("v{}", v),
|
||||
);
|
||||
assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name));
|
||||
}
|
||||
}
|
||||
|
|
@ -8,7 +8,7 @@ use std::ops::Deref;
|
|||
use std::rc::Rc;
|
||||
|
||||
#[cfg(test)]
|
||||
mod test_typedef;
|
||||
mod test;
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
|
||||
/// Handle for a type, implementated as a key in the unification table.
|
||||
|
@ -217,10 +217,14 @@ impl Unifier {
|
|||
}
|
||||
TypeEnum::TObj { obj_id, params, .. } => {
|
||||
let name = obj_to_name(*obj_id);
|
||||
if params.len() > 0 {
|
||||
let mut params = params
|
||||
.values()
|
||||
.map(|v| self.stringify(*v, obj_to_name, var_to_name));
|
||||
format!("{}[{}]", name, params.join(", "))
|
||||
} else {
|
||||
name
|
||||
}
|
||||
}
|
||||
TypeEnum::TCall { .. } => "call".to_owned(),
|
||||
TypeEnum::TFunc(signature) => {
|
||||
|
@ -432,6 +436,9 @@ impl Unifier {
|
|||
return Err(format!("Unknown keyword argument {}", k));
|
||||
}
|
||||
}
|
||||
if !required.is_empty() {
|
||||
return Err("Expected more arguments".to_string());
|
||||
}
|
||||
self.unify(*ret, signature.ret)?;
|
||||
*fun.borrow_mut() = Some(instantiated);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue