1
0
forked from M-Labs/nac3

nac3core: fix recursive top level function call

This commit is contained in:
ychenfo 2021-09-14 22:49:20 +08:00
parent 526c18bda0
commit a0662c58e6
2 changed files with 99 additions and 64 deletions

View File

@ -77,7 +77,10 @@ impl TopLevelComposer {
) )
.into(), .into(),
// FIXME: all the big unifier or? // FIXME: all the big unifier or?
unifiers: Arc::new(RwLock::new(vec![(self.unifier.get_shared_unifier(), self.primitives_ty)])), unifiers: Arc::new(RwLock::new(vec![(
self.unifier.get_shared_unifier(),
self.primitives_ty,
)])),
} }
} }
@ -92,7 +95,7 @@ impl TopLevelComposer {
ast: ast::Stmt<()>, ast: ast::Stmt<()>,
resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>, resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>,
mod_path: String, mod_path: String,
) -> Result<(String, DefinitionId), String> { ) -> Result<(String, DefinitionId, Option<Type>), String> {
let defined_class_name = &mut self.defined_class_name; let defined_class_name = &mut self.defined_class_name;
let defined_class_method_name = &mut self.defined_class_method_name; let defined_class_method_name = &mut self.defined_class_method_name;
let defined_function_name = &mut self.defined_function_name; let defined_function_name = &mut self.defined_function_name;
@ -212,7 +215,7 @@ impl TopLevelComposer {
None, None,
)); ));
Ok((class_name, DefinitionId(class_def_id))) Ok((class_name, DefinitionId(class_def_id), None))
} }
ast::StmtKind::FunctionDef { name, .. } => { ast::StmtKind::FunctionDef { name, .. } => {
@ -228,12 +231,13 @@ impl TopLevelComposer {
return Err("duplicate top level function define".into()); return Err("duplicate top level function define".into());
} }
let ty_to_be_unified = self.unifier.get_fresh_var().0;
// add to the definition list // add to the definition list
self.definition_ast_list.push(( self.definition_ast_list.push((
RwLock::new(Self::make_top_level_function_def( RwLock::new(Self::make_top_level_function_def(
name.into(), name.into(),
// dummy here, unify with correct type later // dummy here, unify with correct type later
self.unifier.get_fresh_var().0, ty_to_be_unified,
resolver, resolver,
)) ))
.into(), .into(),
@ -241,7 +245,11 @@ impl TopLevelComposer {
)); ));
// return // return
Ok((fun_name, DefinitionId(self.definition_ast_list.len() - 1))) Ok((
fun_name,
DefinitionId(self.definition_ast_list.len() - 1),
Some(ty_to_be_unified),
))
} }
_ => Err("only registrations of top level classes/functions are supprted".into()), _ => Err("only registrations of top level classes/functions are supprted".into()),
@ -1111,7 +1119,6 @@ impl TopLevelComposer {
/// step 5, analyze and call type inferecer to fill the `instance_to_stmt` of topleveldef::function /// step 5, analyze and call type inferecer to fill the `instance_to_stmt` of topleveldef::function
fn analyze_function_instance(&mut self) -> Result<(), String> { fn analyze_function_instance(&mut self) -> Result<(), String> {
for (id, (def, ast)) in self.definition_ast_list.iter().enumerate() { for (id, (def, ast)) in self.definition_ast_list.iter().enumerate() {
let mut function_def = def.write(); let mut function_def = def.write();
if let TopLevelDef::Function { if let TopLevelDef::Function {
instance_to_stmt, instance_to_stmt,
@ -1120,7 +1127,8 @@ impl TopLevelComposer {
var_id, var_id,
resolver, resolver,
.. ..
} = &mut *function_def { } = &mut *function_def
{
if let TypeEnum::TFunc(func_sig) = self.unifier.get_ty(*signature).as_ref() { if let TypeEnum::TFunc(func_sig) = self.unifier.get_ty(*signature).as_ref() {
let FunSignature { args, ret, vars } = &*func_sig.borrow(); let FunSignature { args, ret, vars } = &*func_sig.borrow();
// None if is not class method // None if is not class method
@ -1134,7 +1142,7 @@ impl TopLevelComposer {
self.extract_def_list().as_slice(), self.extract_def_list().as_slice(),
&mut self.unifier, &mut self.unifier,
&self.primitives_ty, &self.primitives_ty,
&ty_ann &ty_ann,
)?) )?)
} else { } else {
unreachable!("must be class def") unreachable!("must be class def")
@ -1145,12 +1153,12 @@ impl TopLevelComposer {
}; };
let type_var_subst_comb = { let type_var_subst_comb = {
let unifier = &mut self.unifier; let unifier = &mut self.unifier;
let var_ids = vars let var_ids = vars.iter().map(|(id, _)| *id);
.iter()
.map(|(id, _)| *id);
let var_combs = vars let var_combs = vars
.iter() .iter()
.map(|(_, ty)| unifier.get_instantiations(*ty).unwrap_or_else(|| vec![*ty])) .map(|(_, ty)| {
unifier.get_instantiations(*ty).unwrap_or_else(|| vec![*ty])
})
.multi_cartesian_product() .multi_cartesian_product()
.collect_vec(); .collect_vec();
let mut result: Vec<HashMap<u32, Type>> = Default::default(); let mut result: Vec<HashMap<u32, Type>> = Default::default();
@ -1173,16 +1181,16 @@ impl TopLevelComposer {
.map(|a| FuncArg { .map(|a| FuncArg {
name: a.name.clone(), name: a.name.clone(),
ty: unifier.subst(a.ty, &subst).unwrap_or(a.ty), ty: unifier.subst(a.ty, &subst).unwrap_or(a.ty),
default_value: a.default_value.clone() default_value: a.default_value.clone(),
}) })
.collect_vec(); .collect_vec();
let self_type = self_type.map(|x| unifier.subst(x, &subst).unwrap_or(x)); let self_type = self_type.map(|x| unifier.subst(x, &subst).unwrap_or(x));
let mut identifiers = { let mut identifiers = {
// NOTE: none and function args? // NOTE: none and function args?
let mut result: HashSet<String> = HashSet::new(); let mut result: HashSet<String> = HashSet::new();
result.insert("None".into()); result.insert("None".into());
if self_type.is_some(){ if self_type.is_some() {
result.insert("self".into()); result.insert("self".into());
} }
result.extend(inst_args.iter().map(|x| x.name.clone())); result.extend(inst_args.iter().map(|x| x.name.clone()));
@ -1194,7 +1202,10 @@ impl TopLevelComposer {
defined_identifiers: identifiers.clone(), defined_identifiers: identifiers.clone(),
function_data: &mut FunctionData { function_data: &mut FunctionData {
resolver: resolver.as_ref().unwrap().clone(), resolver: resolver.as_ref().unwrap().clone(),
return_type: if self.unifier.unioned(inst_ret, self.primitives_ty.none) { return_type: if self
.unifier
.unioned(inst_ret, self.primitives_ty.none)
{
None None
} else { } else {
Some(inst_ret) Some(inst_ret)
@ -1219,28 +1230,29 @@ impl TopLevelComposer {
} }
}; };
let fun_body = if let ast::StmtKind::FunctionDef { body, .. } = ast.clone().unwrap().node { let fun_body = if let ast::StmtKind::FunctionDef { body, .. } =
body ast.clone().unwrap().node
} else { {
unreachable!("must be function def ast") body
} } else {
.into_iter() unreachable!("must be function def ast")
.map(|b| inferencer.fold_stmt(b)) }
.collect::<Result<Vec<_>, _>>()?; .into_iter()
.map(|b| inferencer.fold_stmt(b))
let returned = inferencer .collect::<Result<Vec<_>, _>>()?;
.check_block(fun_body.as_slice(), &mut identifiers)?;
let returned =
inferencer.check_block(fun_body.as_slice(), &mut identifiers)?;
if !self.unifier.unioned(inst_ret, self.primitives_ty.none) && !returned { if !self.unifier.unioned(inst_ret, self.primitives_ty.none) && !returned {
let ret_str = self.unifier.stringify( let ret_str = self.unifier.stringify(
inst_ret, inst_ret,
&mut |id| format!("class{}", id), &mut |id| format!("class{}", id),
&mut |id| format!("tvar{}", id) &mut |id| format!("tvar{}", id),
); );
return Err(format!( return Err(format!(
"expected return type of {} in function `{}`", "expected return type of {} in function `{}`",
ret_str, ret_str, name
name
)); ));
} }
@ -1251,17 +1263,16 @@ impl TopLevelComposer {
body: fun_body, body: fun_body,
unifier_id: 0, unifier_id: 0,
calls: HashMap::new(), calls: HashMap::new(),
subst subst,
} },
); );
} }
} else { } else {
unreachable!("must be typeenum::tfunc") unreachable!("must be typeenum::tfunc")
} }
} else { } else {
continue continue;
} }
} }
Ok(()) Ok(())
} }

View File

@ -25,12 +25,17 @@ impl ResolverInternal {
fn add_id_def(&self, id: String, def: DefinitionId) { fn add_id_def(&self, id: String, def: DefinitionId) {
self.id_to_def.lock().insert(id, def); self.id_to_def.lock().insert(id, def);
} }
fn add_id_type(&self, id: String, ty: Type) {
self.id_to_type.lock().insert(id, ty);
}
} }
struct Resolver(Arc<ResolverInternal>); struct Resolver(Arc<ResolverInternal>);
impl SymbolResolver for Resolver { impl SymbolResolver for Resolver {
fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: &str) -> Option<Type> { fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: &str) -> Option<Type> {
println!("unkonw here resolver {}", str);
self.0.id_to_type.lock().get(str).cloned() self.0.id_to_type.lock().get(str).cloned()
} }
@ -133,9 +138,12 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
let ast = parse_program(s).unwrap(); let ast = parse_program(s).unwrap();
let ast = ast[0].clone(); let ast = ast[0].clone();
let (id, def_id) = let (id, def_id, ty) =
composer.register_top_level(ast, Some(resolver.clone()), "__main__".into()).unwrap(); composer.register_top_level(ast, Some(resolver.clone()), "__main__".into()).unwrap();
internal_resolver.add_id_def(id, def_id); internal_resolver.add_id_def(id.clone(), def_id);
if let Some(ty) = ty {
internal_resolver.add_id_type(id, ty);
}
} }
composer.start_analysis(true).unwrap(); composer.start_analysis(true).unwrap();
@ -786,7 +794,7 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
let ast = parse_program(s).unwrap(); let ast = parse_program(s).unwrap();
let ast = ast[0].clone(); let ast = ast[0].clone();
let (id, def_id) = { let (id, def_id, ty) = {
match composer.register_top_level(ast, Some(resolver.clone()), "__main__".into()) { match composer.register_top_level(ast, Some(resolver.clone()), "__main__".into()) {
Ok(x) => x, Ok(x) => x,
Err(msg) => { Err(msg) => {
@ -799,7 +807,10 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
} }
} }
}; };
internal_resolver.add_id_def(id, def_id); internal_resolver.add_id_def(id.clone(), def_id);
if let Some(ty) = ty {
internal_resolver.add_id_type(id, ty);
}
} }
if let Err(msg) = composer.start_analysis(false) { if let Err(msg) = composer.start_analysis(false) {
@ -846,6 +857,14 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
indoc! {" indoc! {"
def fun(a: int32, b: int32) -> int32: def fun(a: int32, b: int32) -> int32:
return a + b return a + b
"},
indoc! {"
def fib(n: int32) -> int32:
if n <= 2:
return 1
a = fib(n - 1)
b = fib(n - 2)
return fib(n - 1)
"} "}
], ],
vec![]; vec![];
@ -861,14 +880,24 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
def fun(self) -> int32: def fun(self) -> int32:
b = self.a + 3 b = self.a + 3
return b * self.a return b * self.a
def dup(self) -> A: def clone(self) -> A:
SELF = self SELF = self
return SELF return SELF
def sum(self) -> int32:
if self.a == 0:
return self.a
else:
a = self.a
self.a = self.a - 1
return a + self.sum()
def fib(self, a: int32) -> int32:
if a <= 2:
return 1
return self.fib(a - 1) + self.fib(a - 2)
"}, "},
indoc! {" indoc! {"
def fun(a: A) -> int32: def fun(a: A) -> int32:
return a.fun() return a.fun() + 2
"} "}
], ],
vec![]; vec![];
@ -878,21 +907,9 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) {
let print = true; let print = true;
let mut composer = TopLevelComposer::new(); let mut composer = TopLevelComposer::new();
let tvar_t = composer.unifier.get_fresh_var();
let tvar_v = composer
.unifier
.get_fresh_var_with_range(&[composer.primitives_ty.bool, composer.primitives_ty.int32]);
if print {
println!("t: {}, {:?}", tvar_t.1, tvar_t.0);
println!("v: {}, {:?}\n", tvar_v.1, tvar_v.0);
}
let internal_resolver = Arc::new(ResolverInternal { let internal_resolver = Arc::new(ResolverInternal {
id_to_def: Default::default(), id_to_def: Default::default(),
id_to_type: Mutex::new( id_to_type: Default::default(),
vec![("T".to_string(), tvar_t.0), ("V".to_string(), tvar_v.0)].into_iter().collect(),
),
class_names: Default::default(), class_names: Default::default(),
}); });
let resolver = Arc::new( let resolver = Arc::new(
@ -903,7 +920,7 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) {
let ast = parse_program(s).unwrap(); let ast = parse_program(s).unwrap();
let ast = ast[0].clone(); let ast = ast[0].clone();
let (id, def_id) = { let (id, def_id, ty) = {
match composer.register_top_level(ast, Some(resolver.clone()), "__main__".into()) { match composer.register_top_level(ast, Some(resolver.clone()), "__main__".into()) {
Ok(x) => x, Ok(x) => x,
Err(msg) => { Err(msg) => {
@ -916,12 +933,14 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) {
} }
} }
}; };
internal_resolver.add_id_def(id, def_id); internal_resolver.add_id_def(id.clone(), def_id);
if let Some(ty) = ty {
internal_resolver.add_id_type(id, ty);
}
} }
if let Err(msg) = composer.start_analysis(true) { if let Err(msg) = composer.start_analysis(true) {
if print { if print {
// println!("err2:");
println!("{}", msg); println!("{}", msg);
} else { } else {
assert_eq!(res[0], msg); assert_eq!(res[0], msg);
@ -931,12 +950,17 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) {
for (i, (def, _)) in composer.definition_ast_list.iter().skip(5).enumerate() { for (i, (def, _)) in composer.definition_ast_list.iter().skip(5).enumerate() {
let def = &*def.read(); let def = &*def.read();
if let TopLevelDef::Function { instance_to_stmt, .. } = def { if let TopLevelDef::Function { instance_to_stmt, name, .. } = def {
for inst in instance_to_stmt.iter() { for inst in instance_to_stmt.iter() {
let ast = &inst.1.body; let ast = &inst.1.body;
println!("{:?}", ast) println!("{}:", name);
for b in ast {
println!("{:?}", b);
println!("--------------------");
}
println!("\n");
} }
} }
} }
} }
} }