forked from M-Labs/nac3
nac3core: top level inferencer call with type var more test
This commit is contained in:
parent
41e63f24d0
commit
ed5dfd4100
@ -1151,9 +1151,16 @@ impl TopLevelComposer {
|
|||||||
None
|
None
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
let type_var_subst_comb = {
|
let (type_var_subst_comb, no_range_vars) = {
|
||||||
let unifier = &mut self.unifier;
|
let unifier = &mut self.unifier;
|
||||||
let var_ids = vars.iter().map(|(id, _)| *id);
|
let mut no_ranges: Vec<Type> = Vec::new();
|
||||||
|
let var_ids = vars.iter().map(|(id, ty)| {
|
||||||
|
if matches!(unifier.get_ty(*ty).as_ref(), TypeEnum::TVar { range, .. } if range.borrow().is_empty()) {
|
||||||
|
no_ranges.push(*ty);
|
||||||
|
}
|
||||||
|
*id
|
||||||
|
})
|
||||||
|
.collect_vec();
|
||||||
let var_combs = vars
|
let var_combs = vars
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(_, ty)| {
|
.map(|(_, ty)| {
|
||||||
@ -1163,28 +1170,33 @@ impl TopLevelComposer {
|
|||||||
.collect_vec();
|
.collect_vec();
|
||||||
let mut result: Vec<HashMap<u32, Type>> = Default::default();
|
let mut result: Vec<HashMap<u32, Type>> = Default::default();
|
||||||
for comb in var_combs {
|
for comb in var_combs {
|
||||||
result.push(var_ids.clone().zip(comb).collect());
|
result.push(var_ids.clone().into_iter().zip(comb).collect());
|
||||||
}
|
}
|
||||||
// NOTE: if is empty, means no type var, append a empty subst, ok to do this?
|
// NOTE: if is empty, means no type var, append a empty subst, ok to do this?
|
||||||
if result.is_empty() {
|
if result.is_empty() {
|
||||||
result.push(HashMap::new())
|
result.push(HashMap::new())
|
||||||
}
|
}
|
||||||
result
|
(result, no_ranges)
|
||||||
};
|
};
|
||||||
|
|
||||||
for subst in type_var_subst_comb {
|
for subst in type_var_subst_comb {
|
||||||
// for each instance
|
// for each instance
|
||||||
let unifier = &mut self.unifier;
|
let inst_ret = self.unifier.subst(*ret, &subst).unwrap_or(*ret);
|
||||||
let inst_ret = unifier.subst(*ret, &subst).unwrap_or(*ret);
|
let inst_args = {
|
||||||
let inst_args = args
|
let unifier = &mut self.unifier;
|
||||||
.iter()
|
args
|
||||||
.map(|a| FuncArg {
|
.iter()
|
||||||
name: a.name.clone(),
|
.map(|a| FuncArg {
|
||||||
ty: unifier.subst(a.ty, &subst).unwrap_or(a.ty),
|
name: a.name.clone(),
|
||||||
default_value: a.default_value.clone(),
|
ty: unifier.subst(a.ty, &subst).unwrap_or(a.ty),
|
||||||
})
|
default_value: a.default_value.clone(),
|
||||||
.collect_vec();
|
})
|
||||||
let self_type = self_type.map(|x| unifier.subst(x, &subst).unwrap_or(x));
|
.collect_vec()
|
||||||
|
};
|
||||||
|
let self_type = {
|
||||||
|
let unifier = &mut self.unifier;
|
||||||
|
self_type.map(|x| unifier.subst(x, &subst).unwrap_or(x))
|
||||||
|
};
|
||||||
|
|
||||||
let mut identifiers = {
|
let mut identifiers = {
|
||||||
// NOTE: none and function args?
|
// NOTE: none and function args?
|
||||||
@ -1196,38 +1208,36 @@ impl TopLevelComposer {
|
|||||||
result.extend(inst_args.iter().map(|x| x.name.clone()));
|
result.extend(inst_args.iter().map(|x| x.name.clone()));
|
||||||
result
|
result
|
||||||
};
|
};
|
||||||
let mut inferencer = {
|
let mut inferencer = Inferencer {
|
||||||
Inferencer {
|
top_level: &self.make_top_level_context(),
|
||||||
top_level: &self.make_top_level_context(),
|
defined_identifiers: identifiers.clone(),
|
||||||
defined_identifiers: identifiers.clone(),
|
function_data: &mut FunctionData {
|
||||||
function_data: &mut FunctionData {
|
resolver: resolver.as_ref().unwrap().clone(),
|
||||||
resolver: resolver.as_ref().unwrap().clone(),
|
return_type: if self
|
||||||
return_type: if self
|
.unifier
|
||||||
.unifier
|
.unioned(inst_ret, self.primitives_ty.none)
|
||||||
.unioned(inst_ret, self.primitives_ty.none)
|
{
|
||||||
{
|
None
|
||||||
None
|
} else {
|
||||||
} else {
|
Some(inst_ret)
|
||||||
Some(inst_ret)
|
|
||||||
},
|
|
||||||
// NOTE: allowed type vars: leave blank?
|
|
||||||
bound_variables: Vec::new(),
|
|
||||||
},
|
},
|
||||||
unifier: &mut self.unifier,
|
// NOTE: allowed type vars
|
||||||
variable_mapping: {
|
bound_variables: no_range_vars.clone(),
|
||||||
// NOTE: none and function args?
|
},
|
||||||
let mut result: HashMap<String, Type> = HashMap::new();
|
unifier: &mut self.unifier,
|
||||||
result.insert("None".into(), self.primitives_ty.none);
|
variable_mapping: {
|
||||||
if let Some(self_ty) = self_type {
|
// NOTE: none and function args?
|
||||||
result.insert("self".into(), self_ty);
|
let mut result: HashMap<String, Type> = HashMap::new();
|
||||||
}
|
result.insert("None".into(), self.primitives_ty.none);
|
||||||
result.extend(inst_args.iter().map(|x| (x.name.clone(), x.ty)));
|
if let Some(self_ty) = self_type {
|
||||||
result
|
result.insert("self".into(), self_ty);
|
||||||
},
|
}
|
||||||
primitives: &self.primitives_ty,
|
result.extend(inst_args.iter().map(|x| (x.name.clone(), x.ty)));
|
||||||
virtual_checks: &mut Vec::new(),
|
result
|
||||||
calls: &mut HashMap::new(),
|
},
|
||||||
}
|
primitives: &self.primitives_ty,
|
||||||
|
virtual_checks: &mut Vec::new(),
|
||||||
|
calls: &mut HashMap::new(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let fun_body = if let ast::StmtKind::FunctionDef { body, .. } =
|
let fun_body = if let ast::StmtKind::FunctionDef { body, .. } =
|
||||||
@ -1257,8 +1267,17 @@ impl TopLevelComposer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
instance_to_stmt.insert(
|
instance_to_stmt.insert(
|
||||||
// FIXME: how?
|
// NOTE: refer to codegen/expr/get_subst_key function
|
||||||
"".to_string(),
|
{
|
||||||
|
let unifier = &mut self.unifier;
|
||||||
|
subst
|
||||||
|
.keys()
|
||||||
|
.sorted()
|
||||||
|
.map(|id| {
|
||||||
|
let ty = subst.get(id).unwrap();
|
||||||
|
unifier.stringify(*ty, &mut |id| id.to_string(), &mut |id| id.to_string())
|
||||||
|
}).join(", ")
|
||||||
|
},
|
||||||
FunInstance {
|
FunInstance {
|
||||||
body: fun_body,
|
body: fun_body,
|
||||||
unifier_id: 0,
|
unifier_id: 0,
|
||||||
|
@ -37,7 +37,7 @@ impl SymbolResolver for Resolver {
|
|||||||
fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: &str) -> Option<Type> {
|
fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: &str) -> Option<Type> {
|
||||||
let ret = self.0.id_to_type.lock().get(str).cloned();
|
let ret = self.0.id_to_type.lock().get(str).cloned();
|
||||||
if ret.is_none() {
|
if ret.is_none() {
|
||||||
println!("unknown here resolver {}", str);
|
// println!("unknown here resolver {}", str);
|
||||||
}
|
}
|
||||||
ret
|
ret
|
||||||
}
|
}
|
||||||
@ -772,23 +772,15 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
|
|||||||
let print = false;
|
let print = false;
|
||||||
let mut composer = TopLevelComposer::new();
|
let mut composer = TopLevelComposer::new();
|
||||||
|
|
||||||
let tvar_t = composer.unifier.get_fresh_var();
|
let internal_resolver = make_internal_resolver_with_tvar(
|
||||||
let tvar_v = composer
|
vec![
|
||||||
.unifier
|
("T".into(), vec![]),
|
||||||
.get_fresh_var_with_range(&[composer.primitives_ty.bool, composer.primitives_ty.int32]);
|
("V".into(), vec![composer.primitives_ty.bool, composer.primitives_ty.int32]),
|
||||||
|
("G".into(), vec![composer.primitives_ty.bool, composer.primitives_ty.int64]),
|
||||||
if print {
|
],
|
||||||
println!("t: {}, {:?}", tvar_t.1, tvar_t.0);
|
&mut composer.unifier,
|
||||||
println!("v: {}, {:?}\n", tvar_v.1, tvar_v.0);
|
print
|
||||||
}
|
);
|
||||||
|
|
||||||
let internal_resolver = Arc::new(ResolverInternal {
|
|
||||||
id_to_def: Default::default(),
|
|
||||||
id_to_type: Mutex::new(
|
|
||||||
vec![("T".to_string(), tvar_t.0), ("V".to_string(), tvar_v.0)].into_iter().collect(),
|
|
||||||
),
|
|
||||||
class_names: Default::default(),
|
|
||||||
});
|
|
||||||
let resolver = Arc::new(
|
let resolver = Arc::new(
|
||||||
Box::new(Resolver(internal_resolver.clone())) as Box<dyn SymbolResolver + Send + Sync>
|
Box::new(Resolver(internal_resolver.clone())) as Box<dyn SymbolResolver + Send + Sync>
|
||||||
);
|
);
|
||||||
@ -888,7 +880,7 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
|
|||||||
return SELF
|
return SELF
|
||||||
def sum(self) -> int32:
|
def sum(self) -> int32:
|
||||||
if self.a == 0:
|
if self.a == 0:
|
||||||
return self.a
|
return self.a + self
|
||||||
else:
|
else:
|
||||||
a = self.a
|
a = self.a
|
||||||
self.a = self.a - 1
|
self.a = self.a - 1
|
||||||
@ -909,34 +901,58 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
|
|||||||
#[test_case(
|
#[test_case(
|
||||||
vec![
|
vec![
|
||||||
indoc! {"
|
indoc! {"
|
||||||
def fun(a: V) -> V:
|
def fun(a: V, c: G, t: T) -> V:
|
||||||
b = a
|
b = a
|
||||||
return a
|
cc = c
|
||||||
|
ret = fun(b, cc, t)
|
||||||
|
return ret * ret
|
||||||
|
"},
|
||||||
|
indoc! {"
|
||||||
|
def sum3(l: list[V]) -> V:
|
||||||
|
return l[0] + l[1] + l[2]
|
||||||
|
"},
|
||||||
|
indoc! {"
|
||||||
|
def sum_sq_pair(p: tuple[V, V]) -> list[V]:
|
||||||
|
a = p[0]
|
||||||
|
b = p[1]
|
||||||
|
a = a**a
|
||||||
|
b = b**b
|
||||||
|
return [a, b]
|
||||||
"}
|
"}
|
||||||
],
|
],
|
||||||
vec![];
|
vec![];
|
||||||
"type var fun"
|
"type var fun"
|
||||||
)]
|
)]
|
||||||
|
#[test_case(
|
||||||
|
vec![
|
||||||
|
indoc! {"
|
||||||
|
class A(Generic[G]):
|
||||||
|
a: G
|
||||||
|
b: bool
|
||||||
|
def __init__(self, aa: G):
|
||||||
|
self.a = aa
|
||||||
|
self.b = True
|
||||||
|
def fun(self, a: G) -> list[G]:
|
||||||
|
ret = [a, self.a]
|
||||||
|
return ret if self.b else self.fun(self.a)
|
||||||
|
"}
|
||||||
|
],
|
||||||
|
vec![];
|
||||||
|
"type var class"
|
||||||
|
)]
|
||||||
fn test_inference(source: Vec<&str>, res: Vec<&str>) {
|
fn test_inference(source: Vec<&str>, res: Vec<&str>) {
|
||||||
let print = true;
|
let print = true;
|
||||||
let mut composer = TopLevelComposer::new();
|
let mut composer = TopLevelComposer::new();
|
||||||
|
|
||||||
let tvar_t = composer.unifier.get_fresh_var();
|
let internal_resolver = make_internal_resolver_with_tvar(
|
||||||
let tvar_v = composer
|
vec![
|
||||||
.unifier
|
("T".into(), vec![]),
|
||||||
.get_fresh_var_with_range(&[composer.primitives_ty.bool, composer.primitives_ty.int64]);
|
("V".into(), vec![composer.primitives_ty.float, composer.primitives_ty.int32, composer.primitives_ty.int64]),
|
||||||
if print {
|
("G".into(), vec![composer.primitives_ty.bool, composer.primitives_ty.int64]),
|
||||||
println!("t: {}, {:?}", tvar_t.1, tvar_t.0);
|
],
|
||||||
println!("v: {}, {:?}\n", tvar_v.1, tvar_v.0);
|
&mut composer.unifier,
|
||||||
}
|
print
|
||||||
|
);
|
||||||
let internal_resolver = Arc::new(ResolverInternal {
|
|
||||||
id_to_def: Default::default(),
|
|
||||||
id_to_type: Mutex::new(
|
|
||||||
vec![("T".to_string(), tvar_t.0), ("V".to_string(), tvar_v.0)].into_iter().collect(),
|
|
||||||
),
|
|
||||||
class_names: Default::default(),
|
|
||||||
});
|
|
||||||
let resolver = Arc::new(
|
let resolver = Arc::new(
|
||||||
Box::new(Resolver(internal_resolver.clone())) as Box<dyn SymbolResolver + Send + Sync>
|
Box::new(Resolver(internal_resolver.clone())) as Box<dyn SymbolResolver + Send + Sync>
|
||||||
);
|
);
|
||||||
@ -977,9 +993,9 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) {
|
|||||||
let def = &*def.read();
|
let def = &*def.read();
|
||||||
|
|
||||||
if let TopLevelDef::Function { instance_to_stmt, name, .. } = def {
|
if let TopLevelDef::Function { instance_to_stmt, name, .. } = def {
|
||||||
|
println!("=========`{}`: number of instances: {}===========", name, instance_to_stmt.len());
|
||||||
for inst in instance_to_stmt.iter() {
|
for inst in instance_to_stmt.iter() {
|
||||||
let ast = &inst.1.body;
|
let ast = &inst.1.body;
|
||||||
println!("{}:", name);
|
|
||||||
for b in ast {
|
for b in ast {
|
||||||
println!("{:?}", stringify_folder.fold_stmt(b.clone()).unwrap());
|
println!("{:?}", stringify_folder.fold_stmt(b.clone()).unwrap());
|
||||||
println!("--------------------");
|
println!("--------------------");
|
||||||
@ -991,6 +1007,31 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn make_internal_resolver_with_tvar(tvars: Vec<(String, Vec<Type>)>, unifier: &mut Unifier, print: bool) -> Arc<ResolverInternal> {
|
||||||
|
let res: Arc<ResolverInternal> = ResolverInternal {
|
||||||
|
id_to_def: Default::default(),
|
||||||
|
id_to_type: tvars
|
||||||
|
.into_iter()
|
||||||
|
.map(|(name, range)| (
|
||||||
|
name.clone(),
|
||||||
|
{
|
||||||
|
let (ty, id) = unifier.get_fresh_var_with_range(range.as_slice());
|
||||||
|
if print {
|
||||||
|
println!("{}: {:?}, tvar{}", name, ty, id);
|
||||||
|
}
|
||||||
|
ty
|
||||||
|
}
|
||||||
|
))
|
||||||
|
.collect::<HashMap<_, _>>()
|
||||||
|
.into(),
|
||||||
|
class_names: Default::default()
|
||||||
|
}.into();
|
||||||
|
if print {
|
||||||
|
println!();
|
||||||
|
}
|
||||||
|
res
|
||||||
|
}
|
||||||
|
|
||||||
struct TypeToStringFolder<'a> {
|
struct TypeToStringFolder<'a> {
|
||||||
unifier: &'a mut Unifier
|
unifier: &'a mut Unifier
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user