forked from M-Labs/nac3
nac3core: fix recursive top level function call
This commit is contained in:
parent
526c18bda0
commit
a0662c58e6
|
@ -77,7 +77,10 @@ impl TopLevelComposer {
|
||||||
)
|
)
|
||||||
.into(),
|
.into(),
|
||||||
// FIXME: all the big unifier or?
|
// FIXME: all the big unifier or?
|
||||||
unifiers: Arc::new(RwLock::new(vec![(self.unifier.get_shared_unifier(), self.primitives_ty)])),
|
unifiers: Arc::new(RwLock::new(vec![(
|
||||||
|
self.unifier.get_shared_unifier(),
|
||||||
|
self.primitives_ty,
|
||||||
|
)])),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -92,7 +95,7 @@ impl TopLevelComposer {
|
||||||
ast: ast::Stmt<()>,
|
ast: ast::Stmt<()>,
|
||||||
resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>,
|
resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>,
|
||||||
mod_path: String,
|
mod_path: String,
|
||||||
) -> Result<(String, DefinitionId), String> {
|
) -> Result<(String, DefinitionId, Option<Type>), String> {
|
||||||
let defined_class_name = &mut self.defined_class_name;
|
let defined_class_name = &mut self.defined_class_name;
|
||||||
let defined_class_method_name = &mut self.defined_class_method_name;
|
let defined_class_method_name = &mut self.defined_class_method_name;
|
||||||
let defined_function_name = &mut self.defined_function_name;
|
let defined_function_name = &mut self.defined_function_name;
|
||||||
|
@ -212,7 +215,7 @@ impl TopLevelComposer {
|
||||||
None,
|
None,
|
||||||
));
|
));
|
||||||
|
|
||||||
Ok((class_name, DefinitionId(class_def_id)))
|
Ok((class_name, DefinitionId(class_def_id), None))
|
||||||
}
|
}
|
||||||
|
|
||||||
ast::StmtKind::FunctionDef { name, .. } => {
|
ast::StmtKind::FunctionDef { name, .. } => {
|
||||||
|
@ -228,12 +231,13 @@ impl TopLevelComposer {
|
||||||
return Err("duplicate top level function define".into());
|
return Err("duplicate top level function define".into());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let ty_to_be_unified = self.unifier.get_fresh_var().0;
|
||||||
// add to the definition list
|
// add to the definition list
|
||||||
self.definition_ast_list.push((
|
self.definition_ast_list.push((
|
||||||
RwLock::new(Self::make_top_level_function_def(
|
RwLock::new(Self::make_top_level_function_def(
|
||||||
name.into(),
|
name.into(),
|
||||||
// dummy here, unify with correct type later
|
// dummy here, unify with correct type later
|
||||||
self.unifier.get_fresh_var().0,
|
ty_to_be_unified,
|
||||||
resolver,
|
resolver,
|
||||||
))
|
))
|
||||||
.into(),
|
.into(),
|
||||||
|
@ -241,7 +245,11 @@ impl TopLevelComposer {
|
||||||
));
|
));
|
||||||
|
|
||||||
// return
|
// return
|
||||||
Ok((fun_name, DefinitionId(self.definition_ast_list.len() - 1)))
|
Ok((
|
||||||
|
fun_name,
|
||||||
|
DefinitionId(self.definition_ast_list.len() - 1),
|
||||||
|
Some(ty_to_be_unified),
|
||||||
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
_ => Err("only registrations of top level classes/functions are supprted".into()),
|
_ => Err("only registrations of top level classes/functions are supprted".into()),
|
||||||
|
@ -1111,7 +1119,6 @@ impl TopLevelComposer {
|
||||||
/// step 5, analyze and call type inferecer to fill the `instance_to_stmt` of topleveldef::function
|
/// step 5, analyze and call type inferecer to fill the `instance_to_stmt` of topleveldef::function
|
||||||
fn analyze_function_instance(&mut self) -> Result<(), String> {
|
fn analyze_function_instance(&mut self) -> Result<(), String> {
|
||||||
for (id, (def, ast)) in self.definition_ast_list.iter().enumerate() {
|
for (id, (def, ast)) in self.definition_ast_list.iter().enumerate() {
|
||||||
|
|
||||||
let mut function_def = def.write();
|
let mut function_def = def.write();
|
||||||
if let TopLevelDef::Function {
|
if let TopLevelDef::Function {
|
||||||
instance_to_stmt,
|
instance_to_stmt,
|
||||||
|
@ -1120,7 +1127,8 @@ impl TopLevelComposer {
|
||||||
var_id,
|
var_id,
|
||||||
resolver,
|
resolver,
|
||||||
..
|
..
|
||||||
} = &mut *function_def {
|
} = &mut *function_def
|
||||||
|
{
|
||||||
if let TypeEnum::TFunc(func_sig) = self.unifier.get_ty(*signature).as_ref() {
|
if let TypeEnum::TFunc(func_sig) = self.unifier.get_ty(*signature).as_ref() {
|
||||||
let FunSignature { args, ret, vars } = &*func_sig.borrow();
|
let FunSignature { args, ret, vars } = &*func_sig.borrow();
|
||||||
// None if is not class method
|
// None if is not class method
|
||||||
|
@ -1134,7 +1142,7 @@ impl TopLevelComposer {
|
||||||
self.extract_def_list().as_slice(),
|
self.extract_def_list().as_slice(),
|
||||||
&mut self.unifier,
|
&mut self.unifier,
|
||||||
&self.primitives_ty,
|
&self.primitives_ty,
|
||||||
&ty_ann
|
&ty_ann,
|
||||||
)?)
|
)?)
|
||||||
} else {
|
} else {
|
||||||
unreachable!("must be class def")
|
unreachable!("must be class def")
|
||||||
|
@ -1145,12 +1153,12 @@ impl TopLevelComposer {
|
||||||
};
|
};
|
||||||
let type_var_subst_comb = {
|
let type_var_subst_comb = {
|
||||||
let unifier = &mut self.unifier;
|
let unifier = &mut self.unifier;
|
||||||
let var_ids = vars
|
let var_ids = vars.iter().map(|(id, _)| *id);
|
||||||
.iter()
|
|
||||||
.map(|(id, _)| *id);
|
|
||||||
let var_combs = vars
|
let var_combs = vars
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(_, ty)| unifier.get_instantiations(*ty).unwrap_or_else(|| vec![*ty]))
|
.map(|(_, ty)| {
|
||||||
|
unifier.get_instantiations(*ty).unwrap_or_else(|| vec![*ty])
|
||||||
|
})
|
||||||
.multi_cartesian_product()
|
.multi_cartesian_product()
|
||||||
.collect_vec();
|
.collect_vec();
|
||||||
let mut result: Vec<HashMap<u32, Type>> = Default::default();
|
let mut result: Vec<HashMap<u32, Type>> = Default::default();
|
||||||
|
@ -1173,7 +1181,7 @@ impl TopLevelComposer {
|
||||||
.map(|a| FuncArg {
|
.map(|a| FuncArg {
|
||||||
name: a.name.clone(),
|
name: a.name.clone(),
|
||||||
ty: unifier.subst(a.ty, &subst).unwrap_or(a.ty),
|
ty: unifier.subst(a.ty, &subst).unwrap_or(a.ty),
|
||||||
default_value: a.default_value.clone()
|
default_value: a.default_value.clone(),
|
||||||
})
|
})
|
||||||
.collect_vec();
|
.collect_vec();
|
||||||
let self_type = self_type.map(|x| unifier.subst(x, &subst).unwrap_or(x));
|
let self_type = self_type.map(|x| unifier.subst(x, &subst).unwrap_or(x));
|
||||||
|
@ -1182,7 +1190,7 @@ impl TopLevelComposer {
|
||||||
// NOTE: none and function args?
|
// NOTE: none and function args?
|
||||||
let mut result: HashSet<String> = HashSet::new();
|
let mut result: HashSet<String> = HashSet::new();
|
||||||
result.insert("None".into());
|
result.insert("None".into());
|
||||||
if self_type.is_some(){
|
if self_type.is_some() {
|
||||||
result.insert("self".into());
|
result.insert("self".into());
|
||||||
}
|
}
|
||||||
result.extend(inst_args.iter().map(|x| x.name.clone()));
|
result.extend(inst_args.iter().map(|x| x.name.clone()));
|
||||||
|
@ -1194,7 +1202,10 @@ impl TopLevelComposer {
|
||||||
defined_identifiers: identifiers.clone(),
|
defined_identifiers: identifiers.clone(),
|
||||||
function_data: &mut FunctionData {
|
function_data: &mut FunctionData {
|
||||||
resolver: resolver.as_ref().unwrap().clone(),
|
resolver: resolver.as_ref().unwrap().clone(),
|
||||||
return_type: if self.unifier.unioned(inst_ret, self.primitives_ty.none) {
|
return_type: if self
|
||||||
|
.unifier
|
||||||
|
.unioned(inst_ret, self.primitives_ty.none)
|
||||||
|
{
|
||||||
None
|
None
|
||||||
} else {
|
} else {
|
||||||
Some(inst_ret)
|
Some(inst_ret)
|
||||||
|
@ -1219,7 +1230,9 @@ impl TopLevelComposer {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let fun_body = if let ast::StmtKind::FunctionDef { body, .. } = ast.clone().unwrap().node {
|
let fun_body = if let ast::StmtKind::FunctionDef { body, .. } =
|
||||||
|
ast.clone().unwrap().node
|
||||||
|
{
|
||||||
body
|
body
|
||||||
} else {
|
} else {
|
||||||
unreachable!("must be function def ast")
|
unreachable!("must be function def ast")
|
||||||
|
@ -1228,19 +1241,18 @@ impl TopLevelComposer {
|
||||||
.map(|b| inferencer.fold_stmt(b))
|
.map(|b| inferencer.fold_stmt(b))
|
||||||
.collect::<Result<Vec<_>, _>>()?;
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
|
||||||
let returned = inferencer
|
let returned =
|
||||||
.check_block(fun_body.as_slice(), &mut identifiers)?;
|
inferencer.check_block(fun_body.as_slice(), &mut identifiers)?;
|
||||||
|
|
||||||
if !self.unifier.unioned(inst_ret, self.primitives_ty.none) && !returned {
|
if !self.unifier.unioned(inst_ret, self.primitives_ty.none) && !returned {
|
||||||
let ret_str = self.unifier.stringify(
|
let ret_str = self.unifier.stringify(
|
||||||
inst_ret,
|
inst_ret,
|
||||||
&mut |id| format!("class{}", id),
|
&mut |id| format!("class{}", id),
|
||||||
&mut |id| format!("tvar{}", id)
|
&mut |id| format!("tvar{}", id),
|
||||||
);
|
);
|
||||||
return Err(format!(
|
return Err(format!(
|
||||||
"expected return type of {} in function `{}`",
|
"expected return type of {} in function `{}`",
|
||||||
ret_str,
|
ret_str, name
|
||||||
name
|
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1251,17 +1263,16 @@ impl TopLevelComposer {
|
||||||
body: fun_body,
|
body: fun_body,
|
||||||
unifier_id: 0,
|
unifier_id: 0,
|
||||||
calls: HashMap::new(),
|
calls: HashMap::new(),
|
||||||
subst
|
subst,
|
||||||
}
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
unreachable!("must be typeenum::tfunc")
|
unreachable!("must be typeenum::tfunc")
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
continue
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,12 +25,17 @@ impl ResolverInternal {
|
||||||
fn add_id_def(&self, id: String, def: DefinitionId) {
|
fn add_id_def(&self, id: String, def: DefinitionId) {
|
||||||
self.id_to_def.lock().insert(id, def);
|
self.id_to_def.lock().insert(id, def);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn add_id_type(&self, id: String, ty: Type) {
|
||||||
|
self.id_to_type.lock().insert(id, ty);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Resolver(Arc<ResolverInternal>);
|
struct Resolver(Arc<ResolverInternal>);
|
||||||
|
|
||||||
impl SymbolResolver for Resolver {
|
impl SymbolResolver for Resolver {
|
||||||
fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: &str) -> Option<Type> {
|
fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: &str) -> Option<Type> {
|
||||||
|
println!("unkonw here resolver {}", str);
|
||||||
self.0.id_to_type.lock().get(str).cloned()
|
self.0.id_to_type.lock().get(str).cloned()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -133,9 +138,12 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
|
||||||
let ast = parse_program(s).unwrap();
|
let ast = parse_program(s).unwrap();
|
||||||
let ast = ast[0].clone();
|
let ast = ast[0].clone();
|
||||||
|
|
||||||
let (id, def_id) =
|
let (id, def_id, ty) =
|
||||||
composer.register_top_level(ast, Some(resolver.clone()), "__main__".into()).unwrap();
|
composer.register_top_level(ast, Some(resolver.clone()), "__main__".into()).unwrap();
|
||||||
internal_resolver.add_id_def(id, def_id);
|
internal_resolver.add_id_def(id.clone(), def_id);
|
||||||
|
if let Some(ty) = ty {
|
||||||
|
internal_resolver.add_id_type(id, ty);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
composer.start_analysis(true).unwrap();
|
composer.start_analysis(true).unwrap();
|
||||||
|
@ -786,7 +794,7 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
|
||||||
let ast = parse_program(s).unwrap();
|
let ast = parse_program(s).unwrap();
|
||||||
let ast = ast[0].clone();
|
let ast = ast[0].clone();
|
||||||
|
|
||||||
let (id, def_id) = {
|
let (id, def_id, ty) = {
|
||||||
match composer.register_top_level(ast, Some(resolver.clone()), "__main__".into()) {
|
match composer.register_top_level(ast, Some(resolver.clone()), "__main__".into()) {
|
||||||
Ok(x) => x,
|
Ok(x) => x,
|
||||||
Err(msg) => {
|
Err(msg) => {
|
||||||
|
@ -799,7 +807,10 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
internal_resolver.add_id_def(id, def_id);
|
internal_resolver.add_id_def(id.clone(), def_id);
|
||||||
|
if let Some(ty) = ty {
|
||||||
|
internal_resolver.add_id_type(id, ty);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Err(msg) = composer.start_analysis(false) {
|
if let Err(msg) = composer.start_analysis(false) {
|
||||||
|
@ -846,6 +857,14 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
|
||||||
indoc! {"
|
indoc! {"
|
||||||
def fun(a: int32, b: int32) -> int32:
|
def fun(a: int32, b: int32) -> int32:
|
||||||
return a + b
|
return a + b
|
||||||
|
"},
|
||||||
|
indoc! {"
|
||||||
|
def fib(n: int32) -> int32:
|
||||||
|
if n <= 2:
|
||||||
|
return 1
|
||||||
|
a = fib(n - 1)
|
||||||
|
b = fib(n - 2)
|
||||||
|
return fib(n - 1)
|
||||||
"}
|
"}
|
||||||
],
|
],
|
||||||
vec![];
|
vec![];
|
||||||
|
@ -861,14 +880,24 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
|
||||||
def fun(self) -> int32:
|
def fun(self) -> int32:
|
||||||
b = self.a + 3
|
b = self.a + 3
|
||||||
return b * self.a
|
return b * self.a
|
||||||
def dup(self) -> A:
|
def clone(self) -> A:
|
||||||
SELF = self
|
SELF = self
|
||||||
return SELF
|
return SELF
|
||||||
|
def sum(self) -> int32:
|
||||||
|
if self.a == 0:
|
||||||
|
return self.a
|
||||||
|
else:
|
||||||
|
a = self.a
|
||||||
|
self.a = self.a - 1
|
||||||
|
return a + self.sum()
|
||||||
|
def fib(self, a: int32) -> int32:
|
||||||
|
if a <= 2:
|
||||||
|
return 1
|
||||||
|
return self.fib(a - 1) + self.fib(a - 2)
|
||||||
"},
|
"},
|
||||||
indoc! {"
|
indoc! {"
|
||||||
def fun(a: A) -> int32:
|
def fun(a: A) -> int32:
|
||||||
return a.fun()
|
return a.fun() + 2
|
||||||
"}
|
"}
|
||||||
],
|
],
|
||||||
vec![];
|
vec![];
|
||||||
|
@ -878,21 +907,9 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) {
|
||||||
let print = true;
|
let print = true;
|
||||||
let mut composer = TopLevelComposer::new();
|
let mut composer = TopLevelComposer::new();
|
||||||
|
|
||||||
let tvar_t = composer.unifier.get_fresh_var();
|
|
||||||
let tvar_v = composer
|
|
||||||
.unifier
|
|
||||||
.get_fresh_var_with_range(&[composer.primitives_ty.bool, composer.primitives_ty.int32]);
|
|
||||||
|
|
||||||
if print {
|
|
||||||
println!("t: {}, {:?}", tvar_t.1, tvar_t.0);
|
|
||||||
println!("v: {}, {:?}\n", tvar_v.1, tvar_v.0);
|
|
||||||
}
|
|
||||||
|
|
||||||
let internal_resolver = Arc::new(ResolverInternal {
|
let internal_resolver = Arc::new(ResolverInternal {
|
||||||
id_to_def: Default::default(),
|
id_to_def: Default::default(),
|
||||||
id_to_type: Mutex::new(
|
id_to_type: Default::default(),
|
||||||
vec![("T".to_string(), tvar_t.0), ("V".to_string(), tvar_v.0)].into_iter().collect(),
|
|
||||||
),
|
|
||||||
class_names: Default::default(),
|
class_names: Default::default(),
|
||||||
});
|
});
|
||||||
let resolver = Arc::new(
|
let resolver = Arc::new(
|
||||||
|
@ -903,7 +920,7 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) {
|
||||||
let ast = parse_program(s).unwrap();
|
let ast = parse_program(s).unwrap();
|
||||||
let ast = ast[0].clone();
|
let ast = ast[0].clone();
|
||||||
|
|
||||||
let (id, def_id) = {
|
let (id, def_id, ty) = {
|
||||||
match composer.register_top_level(ast, Some(resolver.clone()), "__main__".into()) {
|
match composer.register_top_level(ast, Some(resolver.clone()), "__main__".into()) {
|
||||||
Ok(x) => x,
|
Ok(x) => x,
|
||||||
Err(msg) => {
|
Err(msg) => {
|
||||||
|
@ -916,12 +933,14 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
internal_resolver.add_id_def(id, def_id);
|
internal_resolver.add_id_def(id.clone(), def_id);
|
||||||
|
if let Some(ty) = ty {
|
||||||
|
internal_resolver.add_id_type(id, ty);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Err(msg) = composer.start_analysis(true) {
|
if let Err(msg) = composer.start_analysis(true) {
|
||||||
if print {
|
if print {
|
||||||
// println!("err2:");
|
|
||||||
println!("{}", msg);
|
println!("{}", msg);
|
||||||
} else {
|
} else {
|
||||||
assert_eq!(res[0], msg);
|
assert_eq!(res[0], msg);
|
||||||
|
@ -931,10 +950,15 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) {
|
||||||
for (i, (def, _)) in composer.definition_ast_list.iter().skip(5).enumerate() {
|
for (i, (def, _)) in composer.definition_ast_list.iter().skip(5).enumerate() {
|
||||||
let def = &*def.read();
|
let def = &*def.read();
|
||||||
|
|
||||||
if let TopLevelDef::Function { instance_to_stmt, .. } = def {
|
if let TopLevelDef::Function { instance_to_stmt, name, .. } = def {
|
||||||
for inst in instance_to_stmt.iter() {
|
for inst in instance_to_stmt.iter() {
|
||||||
let ast = &inst.1.body;
|
let ast = &inst.1.body;
|
||||||
println!("{:?}", ast)
|
println!("{}:", name);
|
||||||
|
for b in ast {
|
||||||
|
println!("{:?}", b);
|
||||||
|
println!("--------------------");
|
||||||
|
}
|
||||||
|
println!("\n");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue