hm-inference #6
|
@ -105,7 +105,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
||||||
.unify(target.custom.unwrap(), value.custom.unwrap())?;
|
.unify(target.custom.unwrap(), value.custom.unwrap())?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ast::StmtKind::AnnAssign { .. } => {}
|
ast::StmtKind::AnnAssign { .. } | ast::StmtKind::Expr { .. } => {}
|
||||||
_ => return Err("Unsupported statement type".to_string()),
|
_ => return Err("Unsupported statement type".to_string()),
|
||||||
};
|
};
|
||||||
Ok(stmt)
|
Ok(stmt)
|
||||||
|
|
|
@ -66,11 +66,7 @@ impl TestEnvironment {
|
||||||
fields: HashMap::new(),
|
fields: HashMap::new(),
|
||||||
params: HashMap::new(),
|
params: HashMap::new(),
|
||||||
});
|
});
|
||||||
type_mapping.insert("int32".into(), int32);
|
type_mapping.insert("None".into(), none);
|
||||||
type_mapping.insert("int64".into(), int64);
|
|
||||||
type_mapping.insert("float".into(), float);
|
|
||||||
type_mapping.insert("bool".into(), bool);
|
|
||||||
type_mapping.insert("none".into(), none);
|
|
||||||
|
|
||||||
let primitives = PrimitiveStore {
|
let primitives = PrimitiveStore {
|
||||||
int32,
|
int32,
|
||||||
|
@ -81,13 +77,20 @@ impl TestEnvironment {
|
||||||
};
|
};
|
||||||
|
|
||||||
let (v0, id) = unifier.get_fresh_var();
|
let (v0, id) = unifier.get_fresh_var();
|
||||||
|
|
||||||
|
let foo_ty = unifier.add_ty(TypeEnum::TObj {
|
||||||
|
obj_id: 5,
|
||||||
|
fields: [("a".into(), v0)].iter().cloned().collect(),
|
||||||
|
params: [(id, v0)].iter().cloned().collect(),
|
||||||
|
});
|
||||||
|
|
||||||
type_mapping.insert(
|
type_mapping.insert(
|
||||||
"foo".into(),
|
"Foo".into(),
|
||||||
unifier.add_ty(TypeEnum::TObj {
|
unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
obj_id: 5,
|
args: vec![],
|
||||||
fields: [("a".into(), v0)].iter().cloned().collect(),
|
ret: foo_ty,
|
||||||
params: [(id, v0)].iter().cloned().collect(),
|
vars: [(id, v0)].iter().cloned().collect(),
|
||||||
}),
|
})),
|
||||||
);
|
);
|
||||||
|
|
||||||
let id_to_name = [
|
let id_to_name = [
|
||||||
|
@ -140,6 +143,22 @@ impl TestEnvironment {
|
||||||
"},
|
"},
|
||||||
[("a", "fn[[x=float, y=float], float]"), ("b", "fn[[x=float], float]"), ("c", "float"), ("d", "float")].iter().cloned().collect()
|
[("a", "fn[[x=float, y=float], float]"), ("b", "fn[[x=float], float]"), ("c", "float"), ("d", "float")].iter().cloned().collect()
|
||||||
; "lambda test")]
|
; "lambda test")]
|
||||||
|
#[test_case(indoc! {"
|
||||||
|
a = lambda x: x
|
||||||
|
b = lambda x: x
|
||||||
|
|
||||||
|
foo1 = Foo()
|
||||||
|
foo2 = Foo()
|
||||||
|
c = a(foo1.a)
|
||||||
|
d = b(foo2.a)
|
||||||
|
|
||||||
|
a(True)
|
||||||
|
b(123)
|
||||||
|
|
||||||
|
"},
|
||||||
|
[("a", "fn[[x=bool], bool]"), ("b", "fn[[x=int32], int32]"), ("c", "bool"),
|
||||||
|
("d", "int32"), ("foo1", "Foo[bool]"), ("foo2", "Foo[int32]")].iter().cloned().collect()
|
||||||
|
; "obj test")]
|
||||||
fn test_basic(source: &str, mapping: HashMap<&str, &str>) {
|
fn test_basic(source: &str, mapping: HashMap<&str, &str>) {
|
||||||
let mut env = TestEnvironment::new();
|
let mut env = TestEnvironment::new();
|
||||||
let id_to_name = std::mem::take(&mut env.id_to_name);
|
let id_to_name = std::mem::take(&mut env.id_to_name);
|
||||||
|
@ -150,6 +169,14 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>) {
|
||||||
.map(|v| inferencer.fold_stmt(v))
|
.map(|v| inferencer.fold_stmt(v))
|
||||||
.collect::<Result<Vec<_>, _>>()
|
.collect::<Result<Vec<_>, _>>()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
for (k, v) in inferencer.variable_mapping.iter() {
|
||||||
|
let name = inferencer.unifier.stringify(
|
||||||
|
*v,
|
||||||
|
&mut |v| id_to_name.get(&v).unwrap().clone(),
|
||||||
|
&mut |v| format!("v{}", v),
|
||||||
|
);
|
||||||
|
println!("{}: {}", k, name);
|
||||||
|
}
|
||||||
for (k, v) in mapping.iter() {
|
for (k, v) in mapping.iter() {
|
||||||
let ty = inferencer.variable_mapping.get(*k).unwrap();
|
let ty = inferencer.variable_mapping.get(*k).unwrap();
|
||||||
let name = inferencer.unifier.stringify(
|
let name = inferencer.unifier.stringify(
|
||||||
|
@ -160,4 +187,3 @@ fn test_basic(source: &str, mapping: HashMap<&str, &str>) {
|
||||||
assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name));
|
assert_eq!(format!("{}: {}", k, v), format!("{}: {}", k, name));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue