forked from M-Labs/nac3
nac3core: top level inferencer without type var should be ok
This commit is contained in:
parent
a10ab81ee7
commit
526c18bda0
|
@ -1,3 +1,7 @@
|
||||||
|
use rustpython_parser::ast::fold::Fold;
|
||||||
|
|
||||||
|
use crate::typecheck::type_inferencer::{FunctionData, Inferencer};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
type DefAst = (Arc<RwLock<TopLevelDef>>, Option<ast::Stmt<()>>);
|
type DefAst = (Arc<RwLock<TopLevelDef>>, Option<ast::Stmt<()>>);
|
||||||
|
@ -14,6 +18,8 @@ pub struct TopLevelComposer {
|
||||||
pub defined_class_name: HashSet<String>,
|
pub defined_class_name: HashSet<String>,
|
||||||
pub defined_class_method_name: HashSet<String>,
|
pub defined_class_method_name: HashSet<String>,
|
||||||
pub defined_function_name: HashSet<String>,
|
pub defined_function_name: HashSet<String>,
|
||||||
|
// get the class def id of a class method
|
||||||
|
pub method_class: HashMap<DefinitionId, DefinitionId>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for TopLevelComposer {
|
impl Default for TopLevelComposer {
|
||||||
|
@ -60,13 +66,14 @@ impl TopLevelComposer {
|
||||||
defined_class_method_name: Default::default(),
|
defined_class_method_name: Default::default(),
|
||||||
defined_class_name: Default::default(),
|
defined_class_name: Default::default(),
|
||||||
defined_function_name: Default::default(),
|
defined_function_name: Default::default(),
|
||||||
|
method_class: Default::default(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn make_top_level_context(self) -> TopLevelContext {
|
pub fn make_top_level_context(&self) -> TopLevelContext {
|
||||||
TopLevelContext {
|
TopLevelContext {
|
||||||
definitions: RwLock::new(
|
definitions: RwLock::new(
|
||||||
self.definition_ast_list.into_iter().map(|(x, ..)| x).collect_vec(),
|
self.definition_ast_list.iter().map(|(x, ..)| x.clone()).collect_vec(),
|
||||||
)
|
)
|
||||||
.into(),
|
.into(),
|
||||||
// FIXME: all the big unifier or?
|
// FIXME: all the big unifier or?
|
||||||
|
@ -186,7 +193,8 @@ impl TopLevelComposer {
|
||||||
for (name, _, id, ty, ..) in &class_method_name_def_ids {
|
for (name, _, id, ty, ..) in &class_method_name_def_ids {
|
||||||
let mut class_def = class_def_ast.0.write();
|
let mut class_def = class_def_ast.0.write();
|
||||||
if let TopLevelDef::Class { methods, .. } = class_def.deref_mut() {
|
if let TopLevelDef::Class { methods, .. } = class_def.deref_mut() {
|
||||||
methods.push((name.clone(), *ty, *id))
|
methods.push((name.clone(), *ty, *id));
|
||||||
|
self.method_class.insert(*id, DefinitionId(class_def_id));
|
||||||
} else {
|
} else {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
}
|
}
|
||||||
|
@ -240,11 +248,14 @@ impl TopLevelComposer {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn start_analysis(&mut self) -> Result<(), String> {
|
pub fn start_analysis(&mut self, inference: bool) -> Result<(), String> {
|
||||||
self.analyze_top_level_class_type_var()?;
|
self.analyze_top_level_class_type_var()?;
|
||||||
self.analyze_top_level_class_bases()?;
|
self.analyze_top_level_class_bases()?;
|
||||||
self.analyze_top_level_class_fields_methods()?;
|
self.analyze_top_level_class_fields_methods()?;
|
||||||
self.analyze_top_level_function()?;
|
self.analyze_top_level_function()?;
|
||||||
|
if inference {
|
||||||
|
self.analyze_function_instance()?;
|
||||||
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1096,4 +1107,162 @@ impl TopLevelComposer {
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// step 5, analyze and call type inferecer to fill the `instance_to_stmt` of topleveldef::function
|
||||||
|
fn analyze_function_instance(&mut self) -> Result<(), String> {
|
||||||
|
for (id, (def, ast)) in self.definition_ast_list.iter().enumerate() {
|
||||||
|
|
||||||
|
let mut function_def = def.write();
|
||||||
|
if let TopLevelDef::Function {
|
||||||
|
instance_to_stmt,
|
||||||
|
name,
|
||||||
|
signature,
|
||||||
|
var_id,
|
||||||
|
resolver,
|
||||||
|
..
|
||||||
|
} = &mut *function_def {
|
||||||
|
if let TypeEnum::TFunc(func_sig) = self.unifier.get_ty(*signature).as_ref() {
|
||||||
|
let FunSignature { args, ret, vars } = &*func_sig.borrow();
|
||||||
|
// None if is not class method
|
||||||
|
let self_type = {
|
||||||
|
if let Some(class_id) = self.method_class.get(&DefinitionId(id)) {
|
||||||
|
let class_def = self.definition_ast_list.get(class_id.0).unwrap();
|
||||||
|
let class_def = class_def.0.read();
|
||||||
|
if let TopLevelDef::Class { type_vars, .. } = &*class_def {
|
||||||
|
let ty_ann = make_self_type_annotation(type_vars, *class_id);
|
||||||
|
Some(get_type_from_type_annotation_kinds(
|
||||||
|
self.extract_def_list().as_slice(),
|
||||||
|
&mut self.unifier,
|
||||||
|
&self.primitives_ty,
|
||||||
|
&ty_ann
|
||||||
|
)?)
|
||||||
|
} else {
|
||||||
|
unreachable!("must be class def")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let type_var_subst_comb = {
|
||||||
|
let unifier = &mut self.unifier;
|
||||||
|
let var_ids = vars
|
||||||
|
.iter()
|
||||||
|
.map(|(id, _)| *id);
|
||||||
|
let var_combs = vars
|
||||||
|
.iter()
|
||||||
|
.map(|(_, ty)| unifier.get_instantiations(*ty).unwrap_or_else(|| vec![*ty]))
|
||||||
|
.multi_cartesian_product()
|
||||||
|
.collect_vec();
|
||||||
|
let mut result: Vec<HashMap<u32, Type>> = Default::default();
|
||||||
|
for comb in var_combs {
|
||||||
|
result.push(var_ids.clone().zip(comb).collect());
|
||||||
|
}
|
||||||
|
// NOTE: if is empty, means no type var, append a empty subst, ok to do this?
|
||||||
|
if result.is_empty() {
|
||||||
|
result.push(HashMap::new())
|
||||||
|
}
|
||||||
|
result
|
||||||
|
};
|
||||||
|
|
||||||
|
for subst in type_var_subst_comb {
|
||||||
|
// for each instance
|
||||||
|
let unifier = &mut self.unifier;
|
||||||
|
let inst_ret = unifier.subst(*ret, &subst).unwrap_or(*ret);
|
||||||
|
let inst_args = args
|
||||||
|
.iter()
|
||||||
|
.map(|a| FuncArg {
|
||||||
|
name: a.name.clone(),
|
||||||
|
ty: unifier.subst(a.ty, &subst).unwrap_or(a.ty),
|
||||||
|
default_value: a.default_value.clone()
|
||||||
|
})
|
||||||
|
.collect_vec();
|
||||||
|
let self_type = self_type.map(|x| unifier.subst(x, &subst).unwrap_or(x));
|
||||||
|
|
||||||
|
let mut identifiers = {
|
||||||
|
// NOTE: none and function args?
|
||||||
|
let mut result: HashSet<String> = HashSet::new();
|
||||||
|
result.insert("None".into());
|
||||||
|
if self_type.is_some(){
|
||||||
|
result.insert("self".into());
|
||||||
|
}
|
||||||
|
result.extend(inst_args.iter().map(|x| x.name.clone()));
|
||||||
|
result
|
||||||
|
};
|
||||||
|
let mut inferencer = {
|
||||||
|
Inferencer {
|
||||||
|
top_level: &self.make_top_level_context(),
|
||||||
|
defined_identifiers: identifiers.clone(),
|
||||||
|
function_data: &mut FunctionData {
|
||||||
|
resolver: resolver.as_ref().unwrap().clone(),
|
||||||
|
return_type: if self.unifier.unioned(inst_ret, self.primitives_ty.none) {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(inst_ret)
|
||||||
|
},
|
||||||
|
// NOTE: allowed type vars: leave blank?
|
||||||
|
bound_variables: Vec::new(),
|
||||||
|
},
|
||||||
|
unifier: &mut self.unifier,
|
||||||
|
variable_mapping: {
|
||||||
|
// NOTE: none and function args?
|
||||||
|
let mut result: HashMap<String, Type> = HashMap::new();
|
||||||
|
result.insert("None".into(), self.primitives_ty.none);
|
||||||
|
if let Some(self_ty) = self_type {
|
||||||
|
result.insert("self".into(), self_ty);
|
||||||
|
}
|
||||||
|
result.extend(inst_args.iter().map(|x| (x.name.clone(), x.ty)));
|
||||||
|
result
|
||||||
|
},
|
||||||
|
primitives: &self.primitives_ty,
|
||||||
|
virtual_checks: &mut Vec::new(),
|
||||||
|
calls: &mut HashMap::new(),
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let fun_body = if let ast::StmtKind::FunctionDef { body, .. } = ast.clone().unwrap().node {
|
||||||
|
body
|
||||||
|
} else {
|
||||||
|
unreachable!("must be function def ast")
|
||||||
|
}
|
||||||
|
.into_iter()
|
||||||
|
.map(|b| inferencer.fold_stmt(b))
|
||||||
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
|
||||||
|
let returned = inferencer
|
||||||
|
.check_block(fun_body.as_slice(), &mut identifiers)?;
|
||||||
|
|
||||||
|
if !self.unifier.unioned(inst_ret, self.primitives_ty.none) && !returned {
|
||||||
|
let ret_str = self.unifier.stringify(
|
||||||
|
inst_ret,
|
||||||
|
&mut |id| format!("class{}", id),
|
||||||
|
&mut |id| format!("tvar{}", id)
|
||||||
|
);
|
||||||
|
return Err(format!(
|
||||||
|
"expected return type of {} in function `{}`",
|
||||||
|
ret_str,
|
||||||
|
name
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
instance_to_stmt.insert(
|
||||||
|
// FIXME: how?
|
||||||
|
"".to_string(),
|
||||||
|
FunInstance {
|
||||||
|
body: fun_body,
|
||||||
|
unifier_id: 0,
|
||||||
|
calls: HashMap::new(),
|
||||||
|
subst
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
unreachable!("must be typeenum::tfunc")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -36,7 +36,7 @@ pub struct FunInstance {
|
||||||
pub unifier_id: usize,
|
pub unifier_id: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum TopLevelDef {
|
pub enum TopLevelDef {
|
||||||
Class {
|
Class {
|
||||||
// name for error messages and symbols
|
// name for error messages and symbols
|
||||||
|
|
|
@ -138,7 +138,7 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
|
||||||
internal_resolver.add_id_def(id, def_id);
|
internal_resolver.add_id_def(id, def_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
composer.start_analysis().unwrap();
|
composer.start_analysis(true).unwrap();
|
||||||
|
|
||||||
for (i, (def, _)) in composer.definition_ast_list.iter().skip(5).enumerate() {
|
for (i, (def, _)) in composer.definition_ast_list.iter().skip(5).enumerate() {
|
||||||
let def = &*def.read();
|
let def = &*def.read();
|
||||||
|
@ -802,7 +802,7 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
|
||||||
internal_resolver.add_id_def(id, def_id);
|
internal_resolver.add_id_def(id, def_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Err(msg) = composer.start_analysis() {
|
if let Err(msg) = composer.start_analysis(false) {
|
||||||
if print {
|
if print {
|
||||||
println!("{}", msg);
|
println!("{}", msg);
|
||||||
} else {
|
} else {
|
||||||
|
@ -840,3 +840,103 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test_case(
|
||||||
|
vec![
|
||||||
|
indoc! {"
|
||||||
|
def fun(a: int32, b: int32) -> int32:
|
||||||
|
return a + b
|
||||||
|
"}
|
||||||
|
],
|
||||||
|
vec![];
|
||||||
|
"simple function"
|
||||||
|
)]
|
||||||
|
#[test_case(
|
||||||
|
vec![
|
||||||
|
indoc! {"
|
||||||
|
class A:
|
||||||
|
a: int32
|
||||||
|
def __init__(self):
|
||||||
|
self.a = 3
|
||||||
|
def fun(self) -> int32:
|
||||||
|
b = self.a + 3
|
||||||
|
return b * self.a
|
||||||
|
def dup(self) -> A:
|
||||||
|
SELF = self
|
||||||
|
return SELF
|
||||||
|
|
||||||
|
"},
|
||||||
|
indoc! {"
|
||||||
|
def fun(a: A) -> int32:
|
||||||
|
return a.fun()
|
||||||
|
"}
|
||||||
|
],
|
||||||
|
vec![];
|
||||||
|
"simple class body"
|
||||||
|
)]
|
||||||
|
fn test_inference(source: Vec<&str>, res: Vec<&str>) {
|
||||||
|
let print = true;
|
||||||
|
let mut composer = TopLevelComposer::new();
|
||||||
|
|
||||||
|
let tvar_t = composer.unifier.get_fresh_var();
|
||||||
|
let tvar_v = composer
|
||||||
|
.unifier
|
||||||
|
.get_fresh_var_with_range(&[composer.primitives_ty.bool, composer.primitives_ty.int32]);
|
||||||
|
|
||||||
|
if print {
|
||||||
|
println!("t: {}, {:?}", tvar_t.1, tvar_t.0);
|
||||||
|
println!("v: {}, {:?}\n", tvar_v.1, tvar_v.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
let internal_resolver = Arc::new(ResolverInternal {
|
||||||
|
id_to_def: Default::default(),
|
||||||
|
id_to_type: Mutex::new(
|
||||||
|
vec![("T".to_string(), tvar_t.0), ("V".to_string(), tvar_v.0)].into_iter().collect(),
|
||||||
|
),
|
||||||
|
class_names: Default::default(),
|
||||||
|
});
|
||||||
|
let resolver = Arc::new(
|
||||||
|
Box::new(Resolver(internal_resolver.clone())) as Box<dyn SymbolResolver + Send + Sync>
|
||||||
|
);
|
||||||
|
|
||||||
|
for s in source {
|
||||||
|
let ast = parse_program(s).unwrap();
|
||||||
|
let ast = ast[0].clone();
|
||||||
|
|
||||||
|
let (id, def_id) = {
|
||||||
|
match composer.register_top_level(ast, Some(resolver.clone()), "__main__".into()) {
|
||||||
|
Ok(x) => x,
|
||||||
|
Err(msg) => {
|
||||||
|
if print {
|
||||||
|
println!("{}", msg);
|
||||||
|
} else {
|
||||||
|
assert_eq!(res[0], msg);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
internal_resolver.add_id_def(id, def_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Err(msg) = composer.start_analysis(true) {
|
||||||
|
if print {
|
||||||
|
// println!("err2:");
|
||||||
|
println!("{}", msg);
|
||||||
|
} else {
|
||||||
|
assert_eq!(res[0], msg);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// skip 5 to skip primitives
|
||||||
|
for (i, (def, _)) in composer.definition_ast_list.iter().skip(5).enumerate() {
|
||||||
|
let def = &*def.read();
|
||||||
|
|
||||||
|
if let TopLevelDef::Function { instance_to_stmt, .. } = def {
|
||||||
|
for inst in instance_to_stmt.iter() {
|
||||||
|
let ast = &inst.1.body;
|
||||||
|
println!("{:?}", ast)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -237,7 +237,7 @@ pub fn get_type_from_type_annotation_kinds(
|
||||||
|
|
||||||
let subst = {
|
let subst = {
|
||||||
// check for compatible range
|
// check for compatible range
|
||||||
// TODO: if allow type var to be applied, need more check
|
// TODO: if allow type var to be applied(now this disallowed in the parse_to_type_annotation), need more check
|
||||||
let mut result: HashMap<u32, Type> = HashMap::new();
|
let mut result: HashMap<u32, Type> = HashMap::new();
|
||||||
for (tvar, p) in type_vars.iter().zip(param_ty) {
|
for (tvar, p) in type_vars.iter().zip(param_ty) {
|
||||||
if let TypeEnum::TVar { id, range, meta: TypeVarMeta::Generic } =
|
if let TypeEnum::TVar { id, range, meta: TypeVarMeta::Generic } =
|
||||||
|
|
Loading…
Reference in New Issue