forked from M-Labs/nac3
nac3core: allow class to have no __init__, function/method name with module path added to ensure uniqueness
This commit is contained in:
parent
3c930ae9ab
commit
dd1be541b8
|
@ -684,7 +684,11 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
return self.gen_call(Some((value.custom.unwrap(), val)), (&signature, fun_id), params);
|
return self.gen_call(
|
||||||
|
Some((value.custom.unwrap(), val)),
|
||||||
|
(&signature, fun_id),
|
||||||
|
params,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
_ => unimplemented!(),
|
_ => unimplemented!(),
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,11 +17,10 @@ pub struct TopLevelComposer {
|
||||||
// keyword list to prevent same user-defined name
|
// keyword list to prevent same user-defined name
|
||||||
pub keyword_list: HashSet<String>,
|
pub keyword_list: HashSet<String>,
|
||||||
// to prevent duplicate definition
|
// to prevent duplicate definition
|
||||||
pub defined_class_name: HashSet<String>,
|
pub defined_names: HashSet<String>,
|
||||||
pub defined_class_method_name: HashSet<String>,
|
|
||||||
pub defined_function_name: HashSet<String>,
|
|
||||||
// get the class def id of a class method
|
// get the class def id of a class method
|
||||||
pub method_class: HashMap<DefinitionId, DefinitionId>,
|
pub method_class: HashMap<DefinitionId, DefinitionId>,
|
||||||
|
// number of built-in function and classes in the definition list, later skip
|
||||||
pub built_in_num: usize,
|
pub built_in_num: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -52,7 +51,7 @@ impl TopLevelComposer {
|
||||||
};
|
};
|
||||||
let primitives_ty = primitives.0;
|
let primitives_ty = primitives.0;
|
||||||
let mut unifier = primitives.1;
|
let mut unifier = primitives.1;
|
||||||
let keyword_list: HashSet<String> = HashSet::from_iter(vec![
|
let mut keyword_list: HashSet<String> = HashSet::from_iter(vec![
|
||||||
"Generic".into(),
|
"Generic".into(),
|
||||||
"virtual".into(),
|
"virtual".into(),
|
||||||
"list".into(),
|
"list".into(),
|
||||||
|
@ -67,9 +66,7 @@ impl TopLevelComposer {
|
||||||
"Kernel".into(),
|
"Kernel".into(),
|
||||||
"KernelImmutable".into(),
|
"KernelImmutable".into(),
|
||||||
]);
|
]);
|
||||||
let mut defined_class_method_name: HashSet<String> = Default::default();
|
let defined_names: HashSet<String> = Default::default();
|
||||||
let mut defined_class_name: HashSet<String> = Default::default();
|
|
||||||
let mut defined_function_name: HashSet<String> = Default::default();
|
|
||||||
let method_class: HashMap<DefinitionId, DefinitionId> = Default::default();
|
let method_class: HashMap<DefinitionId, DefinitionId> = Default::default();
|
||||||
|
|
||||||
let mut built_in_id: HashMap<String, DefinitionId> = Default::default();
|
let mut built_in_id: HashMap<String, DefinitionId> = Default::default();
|
||||||
|
@ -91,9 +88,7 @@ impl TopLevelComposer {
|
||||||
})),
|
})),
|
||||||
None,
|
None,
|
||||||
));
|
));
|
||||||
defined_class_method_name.insert(name.clone());
|
keyword_list.insert(name);
|
||||||
defined_class_name.insert(name.clone());
|
|
||||||
defined_function_name.insert(name);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
(
|
(
|
||||||
|
@ -103,9 +98,7 @@ impl TopLevelComposer {
|
||||||
primitives_ty,
|
primitives_ty,
|
||||||
unifier,
|
unifier,
|
||||||
keyword_list,
|
keyword_list,
|
||||||
defined_class_method_name,
|
defined_names,
|
||||||
defined_class_name,
|
|
||||||
defined_function_name,
|
|
||||||
method_class,
|
method_class,
|
||||||
},
|
},
|
||||||
built_in_id,
|
built_in_id,
|
||||||
|
@ -119,7 +112,7 @@ impl TopLevelComposer {
|
||||||
self.definition_ast_list.iter().map(|(x, ..)| x.clone()).collect_vec(),
|
self.definition_ast_list.iter().map(|(x, ..)| x.clone()).collect_vec(),
|
||||||
)
|
)
|
||||||
.into(),
|
.into(),
|
||||||
// FIXME: all the big unifier or?
|
// NOTE: only one for now
|
||||||
unifiers: Arc::new(RwLock::new(vec![(
|
unifiers: Arc::new(RwLock::new(vec![(
|
||||||
self.unifier.get_shared_unifier(),
|
self.unifier.get_shared_unifier(),
|
||||||
self.primitives_ty,
|
self.primitives_ty,
|
||||||
|
@ -139,23 +132,21 @@ impl TopLevelComposer {
|
||||||
resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>,
|
resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>,
|
||||||
mod_path: String,
|
mod_path: String,
|
||||||
) -> Result<(String, DefinitionId, Option<Type>), String> {
|
) -> Result<(String, DefinitionId, Option<Type>), String> {
|
||||||
let defined_class_name = &mut self.defined_class_name;
|
let defined_names = &mut self.defined_names;
|
||||||
let defined_class_method_name = &mut self.defined_class_method_name;
|
|
||||||
let defined_function_name = &mut self.defined_function_name;
|
|
||||||
match &ast.node {
|
match &ast.node {
|
||||||
ast::StmtKind::ClassDef { name, body, .. } => {
|
ast::StmtKind::ClassDef { name: class_name, body, .. } => {
|
||||||
if self.keyword_list.contains(name) {
|
if self.keyword_list.contains(class_name) {
|
||||||
return Err("cannot use keyword as a class name".into());
|
return Err("cannot use keyword as a class name".into());
|
||||||
}
|
}
|
||||||
if !defined_class_name.insert({
|
if !defined_names.insert({
|
||||||
let mut n = mod_path.clone();
|
let mut n = mod_path.clone();
|
||||||
n.push_str(name.as_str());
|
n.push_str(class_name.as_str());
|
||||||
n
|
n
|
||||||
}) {
|
}) {
|
||||||
return Err("duplicate definition of class".into());
|
return Err("duplicate definition of class".into());
|
||||||
}
|
}
|
||||||
|
|
||||||
let class_name = name.to_string();
|
let class_name = class_name.clone();
|
||||||
let class_def_id = self.definition_ast_list.len();
|
let class_def_id = self.definition_ast_list.len();
|
||||||
|
|
||||||
// since later when registering class method, ast will still be used,
|
// since later when registering class method, ast will still be used,
|
||||||
|
@ -165,8 +156,8 @@ impl TopLevelComposer {
|
||||||
Arc::new(RwLock::new(Self::make_top_level_class_def(
|
Arc::new(RwLock::new(Self::make_top_level_class_def(
|
||||||
class_def_id,
|
class_def_id,
|
||||||
resolver.clone(),
|
resolver.clone(),
|
||||||
name,
|
class_name.as_str(),
|
||||||
Some(constructor_ty)
|
Some(constructor_ty),
|
||||||
))),
|
))),
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
@ -187,23 +178,27 @@ impl TopLevelComposer {
|
||||||
// we do not push anything to the def list, so we keep track of the index
|
// we do not push anything to the def list, so we keep track of the index
|
||||||
// and then push in the correct order after the for loop
|
// and then push in the correct order after the for loop
|
||||||
let mut class_method_index_offset = 0;
|
let mut class_method_index_offset = 0;
|
||||||
let mut has_init = false;
|
|
||||||
for b in body {
|
for b in body {
|
||||||
if let ast::StmtKind::FunctionDef { name: method_name, .. } = &b.node {
|
if let ast::StmtKind::FunctionDef { name: method_name, .. } = &b.node {
|
||||||
if self.keyword_list.contains(name) {
|
if self.keyword_list.contains(method_name) {
|
||||||
return Err("cannot use keyword as a method name".into());
|
return Err("cannot use keyword as a method name".into());
|
||||||
}
|
}
|
||||||
let global_class_method_name =
|
if method_name.ends_with(|x: char| x.is_ascii_digit()) {
|
||||||
Self::make_class_method_name(class_name.clone(), method_name);
|
return Err(format!(
|
||||||
if !defined_class_method_name.insert({
|
"function name `{}` must not end with numbers",
|
||||||
let mut n = mod_path.clone();
|
method_name
|
||||||
n.push_str(global_class_method_name.as_str());
|
));
|
||||||
n
|
|
||||||
}) {
|
|
||||||
return Err("duplicate class method definition".into());
|
|
||||||
}
|
}
|
||||||
if method_name == "__init__" {
|
let global_class_method_name = {
|
||||||
has_init = true;
|
let mut n = mod_path.clone();
|
||||||
|
n.push_str(
|
||||||
|
Self::make_class_method_name(class_name.clone(), method_name)
|
||||||
|
.as_str(),
|
||||||
|
);
|
||||||
|
n
|
||||||
|
};
|
||||||
|
if !defined_names.insert(global_class_method_name.clone()) {
|
||||||
|
return Err("duplicate class method definition".into());
|
||||||
}
|
}
|
||||||
let method_def_id = self.definition_ast_list.len() + {
|
let method_def_id = self.definition_ast_list.len() + {
|
||||||
// plus 1 here since we already have the class def
|
// plus 1 here since we already have the class def
|
||||||
|
@ -232,9 +227,6 @@ impl TopLevelComposer {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !has_init {
|
|
||||||
return Err("class def must have __init__ method defined".into());
|
|
||||||
}
|
|
||||||
|
|
||||||
// move the ast to the entry of the class in the ast_list
|
// move the ast to the entry of the class in the ast_list
|
||||||
class_def_ast.1 = Some(ast);
|
class_def_ast.1 = Some(ast);
|
||||||
|
@ -261,12 +253,16 @@ impl TopLevelComposer {
|
||||||
if self.keyword_list.contains(name) {
|
if self.keyword_list.contains(name) {
|
||||||
return Err("cannot use keyword as a top level function name".into());
|
return Err("cannot use keyword as a top level function name".into());
|
||||||
}
|
}
|
||||||
|
if name.ends_with(|x: char| x.is_ascii_digit()) {
|
||||||
|
return Err(format!("function name `{}` must not end with numbers", name));
|
||||||
|
}
|
||||||
let fun_name = name.to_string();
|
let fun_name = name.to_string();
|
||||||
if !defined_function_name.insert({
|
let global_fun_name = {
|
||||||
let mut n = mod_path;
|
let mut n = mod_path;
|
||||||
n.push_str(name.as_str());
|
n.push_str(name.as_str());
|
||||||
n
|
n
|
||||||
}) {
|
};
|
||||||
|
if !defined_names.insert(global_fun_name.clone()) {
|
||||||
return Err("duplicate top level function define".into());
|
return Err("duplicate top level function define".into());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -274,8 +270,7 @@ impl TopLevelComposer {
|
||||||
// add to the definition list
|
// add to the definition list
|
||||||
self.definition_ast_list.push((
|
self.definition_ast_list.push((
|
||||||
RwLock::new(Self::make_top_level_function_def(
|
RwLock::new(Self::make_top_level_function_def(
|
||||||
// TODO: is this fun_name or the above name with mod_path?
|
global_fun_name,
|
||||||
name.into(),
|
|
||||||
name.into(),
|
name.into(),
|
||||||
// dummy here, unify with correct type later
|
// dummy here, unify with correct type later
|
||||||
ty_to_be_unified,
|
ty_to_be_unified,
|
||||||
|
@ -824,7 +819,8 @@ impl TopLevelComposer {
|
||||||
|
|
||||||
let mut defined_fields: HashSet<String> = HashSet::new();
|
let mut defined_fields: HashSet<String> = HashSet::new();
|
||||||
for b in class_body_ast {
|
for b in class_body_ast {
|
||||||
if let ast::StmtKind::FunctionDef { args, returns, name, .. } = &b.node {
|
match &b.node {
|
||||||
|
ast::StmtKind::FunctionDef { args, returns, name, .. } => {
|
||||||
let (method_dummy_ty, method_id) =
|
let (method_dummy_ty, method_id) =
|
||||||
Self::get_class_method_def_info(class_methods_def, name)?;
|
Self::get_class_method_def_info(class_methods_def, name)?;
|
||||||
|
|
||||||
|
@ -905,7 +901,8 @@ impl TopLevelComposer {
|
||||||
};
|
};
|
||||||
// push the dummy type and the type annotation
|
// push the dummy type and the type annotation
|
||||||
// into the list for later unification
|
// into the list for later unification
|
||||||
type_var_to_concrete_def.insert(dummy_func_arg.ty, type_ann.clone());
|
type_var_to_concrete_def
|
||||||
|
.insert(dummy_func_arg.ty, type_ann.clone());
|
||||||
result.push(dummy_func_arg)
|
result.push(dummy_func_arg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -961,14 +958,15 @@ impl TopLevelComposer {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
let method_type = unifier.add_ty(TypeEnum::TFunc(
|
let method_type = unifier.add_ty(TypeEnum::TFunc(
|
||||||
FunSignature { args: arg_types, ret: ret_type, vars: method_var_map }.into(),
|
FunSignature { args: arg_types, ret: ret_type, vars: method_var_map }
|
||||||
|
.into(),
|
||||||
));
|
));
|
||||||
|
|
||||||
// unify now since function type is not in type annotation define
|
// unify now since function type is not in type annotation define
|
||||||
// which should be fine since type within method_type will be subst later
|
// which should be fine since type within method_type will be subst later
|
||||||
unifier.unify(method_dummy_ty, method_type)?;
|
unifier.unify(method_dummy_ty, method_type)?;
|
||||||
} else if let ast::StmtKind::AnnAssign { target, annotation, value: None, .. } = &b.node
|
}
|
||||||
{
|
ast::StmtKind::AnnAssign { target, annotation, value: None, .. } => {
|
||||||
if let ast::ExprKind::Name { id: attr, .. } = &target.node {
|
if let ast::ExprKind::Name { id: attr, .. } = &target.node {
|
||||||
if defined_fields.insert(attr.to_string()) {
|
if defined_fields.insert(attr.to_string()) {
|
||||||
let dummy_field_type = unifier.get_fresh_var().0;
|
let dummy_field_type = unifier.get_fresh_var().0;
|
||||||
|
@ -1015,11 +1013,14 @@ impl TopLevelComposer {
|
||||||
} else {
|
} else {
|
||||||
return Err("same class fields defined twice".into());
|
return Err("same class fields defined twice".into());
|
||||||
}
|
}
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
return Err("unsupported statement type in class definition body".into());
|
return Err("unsupported statement type in class definition body".into());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
ast::StmtKind::Pass => {}
|
||||||
|
_ => return Err("unsupported statement type in class definition body".into()),
|
||||||
|
}
|
||||||
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1162,8 +1163,14 @@ impl TopLevelComposer {
|
||||||
for (id, (def, ast)) in self.definition_ast_list.iter().enumerate().skip(self.built_in_num)
|
for (id, (def, ast)) in self.definition_ast_list.iter().enumerate().skip(self.built_in_num)
|
||||||
{
|
{
|
||||||
let mut function_def = def.write();
|
let mut function_def = def.write();
|
||||||
if let TopLevelDef::Function { instance_to_stmt, name, simple_name, signature, resolver, .. } =
|
if let TopLevelDef::Function {
|
||||||
&mut *function_def
|
instance_to_stmt,
|
||||||
|
name,
|
||||||
|
simple_name,
|
||||||
|
signature,
|
||||||
|
resolver,
|
||||||
|
..
|
||||||
|
} = &mut *function_def
|
||||||
{
|
{
|
||||||
if let TypeEnum::TFunc(func_sig) = self.unifier.get_ty(*signature).as_ref() {
|
if let TypeEnum::TFunc(func_sig) = self.unifier.get_ty(*signature).as_ref() {
|
||||||
let FunSignature { args, ret, vars } = &*func_sig.borrow();
|
let FunSignature { args, ret, vars } = &*func_sig.borrow();
|
||||||
|
@ -1181,11 +1188,13 @@ impl TopLevelComposer {
|
||||||
&ty_ann,
|
&ty_ann,
|
||||||
)?;
|
)?;
|
||||||
if simple_name == "__init__" {
|
if simple_name == "__init__" {
|
||||||
let fn_type = self.unifier.add_ty(TypeEnum::TFunc(RefCell::new(FunSignature {
|
let fn_type = self.unifier.add_ty(TypeEnum::TFunc(
|
||||||
|
RefCell::new(FunSignature {
|
||||||
args: args.clone(),
|
args: args.clone(),
|
||||||
ret: self_ty,
|
ret: self_ty,
|
||||||
vars: vars.clone()
|
vars: vars.clone(),
|
||||||
})));
|
}),
|
||||||
|
));
|
||||||
self.unifier.unify(fn_type, constructor.unwrap())?;
|
self.unifier.unify(fn_type, constructor.unwrap())?;
|
||||||
}
|
}
|
||||||
Some(self_ty)
|
Some(self_ty)
|
||||||
|
|
|
@ -93,7 +93,7 @@ impl TopLevelComposer {
|
||||||
index: usize,
|
index: usize,
|
||||||
resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>,
|
resolver: Option<Arc<Box<dyn SymbolResolver + Send + Sync>>>,
|
||||||
name: &str,
|
name: &str,
|
||||||
constructor: Option<Type>
|
constructor: Option<Type>,
|
||||||
) -> TopLevelDef {
|
) -> TopLevelDef {
|
||||||
TopLevelDef::Class {
|
TopLevelDef::Class {
|
||||||
name: name.to_string(),
|
name: name.to_string(),
|
||||||
|
|
|
@ -94,7 +94,7 @@ fn test_simple_register(source: Vec<&str>) {
|
||||||
let ast = parse_program(s).unwrap();
|
let ast = parse_program(s).unwrap();
|
||||||
let ast = ast[0].clone();
|
let ast = ast[0].clone();
|
||||||
|
|
||||||
composer.register_top_level(ast, None, "__main__".into()).unwrap();
|
composer.register_top_level(ast, None, "".into()).unwrap();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -142,7 +142,7 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
|
||||||
let ast = ast[0].clone();
|
let ast = ast[0].clone();
|
||||||
|
|
||||||
let (id, def_id, ty) =
|
let (id, def_id, ty) =
|
||||||
composer.register_top_level(ast, Some(resolver.clone()), "__main__".into()).unwrap();
|
composer.register_top_level(ast, Some(resolver.clone()), "".into()).unwrap();
|
||||||
internal_resolver.add_id_def(id.clone(), def_id);
|
internal_resolver.add_id_def(id.clone(), def_id);
|
||||||
if let Some(ty) = ty {
|
if let Some(ty) = ty {
|
||||||
internal_resolver.add_id_type(id, ty);
|
internal_resolver.add_id_type(id, ty);
|
||||||
|
@ -151,7 +151,8 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
|
||||||
|
|
||||||
composer.start_analysis(true).unwrap();
|
composer.start_analysis(true).unwrap();
|
||||||
|
|
||||||
for (i, (def, _)) in composer.definition_ast_list.iter().skip(composer.built_in_num).enumerate() {
|
for (i, (def, _)) in composer.definition_ast_list.iter().skip(composer.built_in_num).enumerate()
|
||||||
|
{
|
||||||
let def = &*def.read();
|
let def = &*def.read();
|
||||||
if let TopLevelDef::Function { signature, name, .. } = def {
|
if let TopLevelDef::Function { signature, name, .. } = def {
|
||||||
let ty_str =
|
let ty_str =
|
||||||
|
@ -638,12 +639,23 @@ fn test_simple_function_analyze(source: Vec<&str>, tys: Vec<&str>, names: Vec<&s
|
||||||
"cyclic2"
|
"cyclic2"
|
||||||
)]
|
)]
|
||||||
#[test_case(
|
#[test_case(
|
||||||
vec![indoc! {"
|
vec![
|
||||||
|
indoc! {"
|
||||||
class A:
|
class A:
|
||||||
pass
|
pass
|
||||||
|
"}
|
||||||
|
],
|
||||||
|
vec!["5: Class {\nname: \"A\",\ndef_id: DefinitionId(5),\nancestors: [CustomClassKind { id: DefinitionId(5), params: [] }],\nfields: [],\nmethods: [],\ntype_vars: []\n}"];
|
||||||
|
"simple pass in class"
|
||||||
|
)]
|
||||||
|
#[test_case(
|
||||||
|
vec![indoc! {"
|
||||||
|
class A:
|
||||||
|
def fun3(self):
|
||||||
|
pass
|
||||||
"}],
|
"}],
|
||||||
vec!["class def must have __init__ method defined"];
|
vec!["function name `fun3` must not end with numbers"];
|
||||||
"err no __init__"
|
"err fun end with number"
|
||||||
)]
|
)]
|
||||||
#[test_case(
|
#[test_case(
|
||||||
vec![indoc! {"
|
vec![indoc! {"
|
||||||
|
@ -755,7 +767,7 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
|
||||||
("G".into(), vec![composer.primitives_ty.bool, composer.primitives_ty.int64]),
|
("G".into(), vec![composer.primitives_ty.bool, composer.primitives_ty.int64]),
|
||||||
],
|
],
|
||||||
&mut composer.unifier,
|
&mut composer.unifier,
|
||||||
print
|
print,
|
||||||
);
|
);
|
||||||
let resolver = Arc::new(
|
let resolver = Arc::new(
|
||||||
Box::new(Resolver(internal_resolver.clone())) as Box<dyn SymbolResolver + Send + Sync>
|
Box::new(Resolver(internal_resolver.clone())) as Box<dyn SymbolResolver + Send + Sync>
|
||||||
|
@ -766,7 +778,7 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
|
||||||
let ast = ast[0].clone();
|
let ast = ast[0].clone();
|
||||||
|
|
||||||
let (id, def_id, ty) = {
|
let (id, def_id, ty) = {
|
||||||
match composer.register_top_level(ast, Some(resolver.clone()), "__main__".into()) {
|
match composer.register_top_level(ast, Some(resolver.clone()), "".into()) {
|
||||||
Ok(x) => x,
|
Ok(x) => x,
|
||||||
Err(msg) => {
|
Err(msg) => {
|
||||||
if print {
|
if print {
|
||||||
|
@ -792,7 +804,9 @@ fn test_analyze(source: Vec<&str>, res: Vec<&str>) {
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// skip 5 to skip primitives
|
// skip 5 to skip primitives
|
||||||
for (i, (def, _)) in composer.definition_ast_list.iter().skip(composer.built_in_num).enumerate() {
|
for (i, (def, _)) in
|
||||||
|
composer.definition_ast_list.iter().skip(composer.built_in_num).enumerate()
|
||||||
|
{
|
||||||
let def = &*def.read();
|
let def = &*def.read();
|
||||||
|
|
||||||
if print {
|
if print {
|
||||||
|
@ -923,11 +937,18 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) {
|
||||||
let internal_resolver = make_internal_resolver_with_tvar(
|
let internal_resolver = make_internal_resolver_with_tvar(
|
||||||
vec![
|
vec![
|
||||||
("T".into(), vec![]),
|
("T".into(), vec![]),
|
||||||
("V".into(), vec![composer.primitives_ty.float, composer.primitives_ty.int32, composer.primitives_ty.int64]),
|
(
|
||||||
|
"V".into(),
|
||||||
|
vec![
|
||||||
|
composer.primitives_ty.float,
|
||||||
|
composer.primitives_ty.int32,
|
||||||
|
composer.primitives_ty.int64,
|
||||||
|
],
|
||||||
|
),
|
||||||
("G".into(), vec![composer.primitives_ty.bool, composer.primitives_ty.int64]),
|
("G".into(), vec![composer.primitives_ty.bool, composer.primitives_ty.int64]),
|
||||||
],
|
],
|
||||||
&mut composer.unifier,
|
&mut composer.unifier,
|
||||||
print
|
print,
|
||||||
);
|
);
|
||||||
let resolver = Arc::new(
|
let resolver = Arc::new(
|
||||||
Box::new(Resolver(internal_resolver.clone())) as Box<dyn SymbolResolver + Send + Sync>
|
Box::new(Resolver(internal_resolver.clone())) as Box<dyn SymbolResolver + Send + Sync>
|
||||||
|
@ -938,7 +959,7 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) {
|
||||||
let ast = ast[0].clone();
|
let ast = ast[0].clone();
|
||||||
|
|
||||||
let (id, def_id, ty) = {
|
let (id, def_id, ty) = {
|
||||||
match composer.register_top_level(ast, Some(resolver.clone()), "__main__".into()) {
|
match composer.register_top_level(ast, Some(resolver.clone()), "".into()) {
|
||||||
Ok(x) => x,
|
Ok(x) => x,
|
||||||
Err(msg) => {
|
Err(msg) => {
|
||||||
if print {
|
if print {
|
||||||
|
@ -964,12 +985,18 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) {
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// skip 5 to skip primitives
|
// skip 5 to skip primitives
|
||||||
let mut stringify_folder = TypeToStringFolder { unifier: &mut composer.unifier};
|
let mut stringify_folder = TypeToStringFolder { unifier: &mut composer.unifier };
|
||||||
for (i, (def, _)) in composer.definition_ast_list.iter().skip(composer.built_in_num).enumerate() {
|
for (_i, (def, _)) in
|
||||||
|
composer.definition_ast_list.iter().skip(composer.built_in_num).enumerate()
|
||||||
|
{
|
||||||
let def = &*def.read();
|
let def = &*def.read();
|
||||||
|
|
||||||
if let TopLevelDef::Function { instance_to_stmt, name, .. } = def {
|
if let TopLevelDef::Function { instance_to_stmt, name, .. } = def {
|
||||||
println!("=========`{}`: number of instances: {}===========", name, instance_to_stmt.len());
|
println!(
|
||||||
|
"=========`{}`: number of instances: {}===========",
|
||||||
|
name,
|
||||||
|
instance_to_stmt.len()
|
||||||
|
);
|
||||||
for inst in instance_to_stmt.iter() {
|
for inst in instance_to_stmt.iter() {
|
||||||
let ast = &inst.1.body;
|
let ast = &inst.1.body;
|
||||||
for b in ast {
|
for b in ast {
|
||||||
|
@ -983,25 +1010,29 @@ fn test_inference(source: Vec<&str>, res: Vec<&str>) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn make_internal_resolver_with_tvar(tvars: Vec<(String, Vec<Type>)>, unifier: &mut Unifier, print: bool) -> Arc<ResolverInternal> {
|
fn make_internal_resolver_with_tvar(
|
||||||
|
tvars: Vec<(String, Vec<Type>)>,
|
||||||
|
unifier: &mut Unifier,
|
||||||
|
print: bool,
|
||||||
|
) -> Arc<ResolverInternal> {
|
||||||
let res: Arc<ResolverInternal> = ResolverInternal {
|
let res: Arc<ResolverInternal> = ResolverInternal {
|
||||||
id_to_def: Default::default(),
|
id_to_def: Default::default(),
|
||||||
id_to_type: tvars
|
id_to_type: tvars
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|(name, range)| (
|
.map(|(name, range)| {
|
||||||
name.clone(),
|
(name.clone(), {
|
||||||
{
|
|
||||||
let (ty, id) = unifier.get_fresh_var_with_range(range.as_slice());
|
let (ty, id) = unifier.get_fresh_var_with_range(range.as_slice());
|
||||||
if print {
|
if print {
|
||||||
println!("{}: {:?}, tvar{}", name, ty, id);
|
println!("{}: {:?}, tvar{}", name, ty, id);
|
||||||
}
|
}
|
||||||
ty
|
ty
|
||||||
}
|
})
|
||||||
))
|
})
|
||||||
.collect::<HashMap<_, _>>()
|
.collect::<HashMap<_, _>>()
|
||||||
.into(),
|
.into(),
|
||||||
class_names: Default::default()
|
class_names: Default::default(),
|
||||||
}.into();
|
}
|
||||||
|
.into();
|
||||||
if print {
|
if print {
|
||||||
println!();
|
println!();
|
||||||
}
|
}
|
||||||
|
@ -1009,7 +1040,7 @@ fn make_internal_resolver_with_tvar(tvars: Vec<(String, Vec<Type>)>, unifier: &m
|
||||||
}
|
}
|
||||||
|
|
||||||
struct TypeToStringFolder<'a> {
|
struct TypeToStringFolder<'a> {
|
||||||
unifier: &'a mut Unifier
|
unifier: &'a mut Unifier,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> Fold<Option<Type>> for TypeToStringFolder<'a> {
|
impl<'a> Fold<Option<Type>> for TypeToStringFolder<'a> {
|
||||||
|
@ -1017,14 +1048,11 @@ impl<'a> Fold<Option<Type>> for TypeToStringFolder<'a> {
|
||||||
type Error = String;
|
type Error = String;
|
||||||
fn map_user(&mut self, user: Option<Type>) -> Result<Self::TargetU, Self::Error> {
|
fn map_user(&mut self, user: Option<Type>) -> Result<Self::TargetU, Self::Error> {
|
||||||
Ok(if let Some(ty) = user {
|
Ok(if let Some(ty) = user {
|
||||||
self.unifier.stringify(
|
self.unifier.stringify(ty, &mut |id| format!("class{}", id.to_string()), &mut |id| {
|
||||||
ty,
|
format!("tvar{}", id.to_string())
|
||||||
&mut |id| format!("class{}", id.to_string()),
|
})
|
||||||
&mut |id| format!("tvar{}", id.to_string()),
|
|
||||||
)
|
|
||||||
} else {
|
} else {
|
||||||
"None".into()
|
"None".into()
|
||||||
}
|
})
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -797,7 +797,12 @@ impl Unifier {
|
||||||
self.subst_impl(a, mapping, &mut HashMap::new())
|
self.subst_impl(a, mapping, &mut HashMap::new())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn subst_impl(&mut self, a: Type, mapping: &VarMap, cache: &mut HashMap<Type, Option<Type>>) -> Option<Type> {
|
fn subst_impl(
|
||||||
|
&mut self,
|
||||||
|
a: Type,
|
||||||
|
mapping: &VarMap,
|
||||||
|
cache: &mut HashMap<Type, Option<Type>>,
|
||||||
|
) -> Option<Type> {
|
||||||
use TypeVarMeta::*;
|
use TypeVarMeta::*;
|
||||||
let cached = cache.get_mut(&a);
|
let cached = cache.get_mut(&a);
|
||||||
if let Some(cached) = cached {
|
if let Some(cached) = cached {
|
||||||
|
@ -831,9 +836,9 @@ impl Unifier {
|
||||||
TypeEnum::TList { ty } => {
|
TypeEnum::TList { ty } => {
|
||||||
self.subst_impl(*ty, mapping, cache).map(|t| self.add_ty(TypeEnum::TList { ty: t }))
|
self.subst_impl(*ty, mapping, cache).map(|t| self.add_ty(TypeEnum::TList { ty: t }))
|
||||||
}
|
}
|
||||||
TypeEnum::TVirtual { ty } => {
|
TypeEnum::TVirtual { ty } => self
|
||||||
self.subst_impl(*ty, mapping, cache).map(|t| self.add_ty(TypeEnum::TVirtual { ty: t }))
|
.subst_impl(*ty, mapping, cache)
|
||||||
}
|
.map(|t| self.add_ty(TypeEnum::TVirtual { ty: t })),
|
||||||
TypeEnum::TObj { obj_id, fields, params } => {
|
TypeEnum::TObj { obj_id, fields, params } => {
|
||||||
// Type variables in field types must be present in the type parameter.
|
// Type variables in field types must be present in the type parameter.
|
||||||
// If the mapping does not contain any type variables in the
|
// If the mapping does not contain any type variables in the
|
||||||
|
@ -851,7 +856,8 @@ impl Unifier {
|
||||||
if need_subst {
|
if need_subst {
|
||||||
cache.insert(a, None);
|
cache.insert(a, None);
|
||||||
let obj_id = *obj_id;
|
let obj_id = *obj_id;
|
||||||
let params = self.subst_map(¶ms, mapping, cache).unwrap_or_else(|| params.clone());
|
let params =
|
||||||
|
self.subst_map(¶ms, mapping, cache).unwrap_or_else(|| params.clone());
|
||||||
let fields = self
|
let fields = self
|
||||||
.subst_map(&fields.borrow(), mapping, cache)
|
.subst_map(&fields.borrow(), mapping, cache)
|
||||||
.unwrap_or_else(|| fields.borrow().clone());
|
.unwrap_or_else(|| fields.borrow().clone());
|
||||||
|
@ -897,7 +903,12 @@ impl Unifier {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn subst_map<K>(&mut self, map: &Mapping<K>, mapping: &VarMap, cache: &mut HashMap<Type, Option<Type>>) -> Option<Mapping<K>>
|
fn subst_map<K>(
|
||||||
|
&mut self,
|
||||||
|
map: &Mapping<K>,
|
||||||
|
mapping: &VarMap,
|
||||||
|
cache: &mut HashMap<Type, Option<Type>>,
|
||||||
|
) -> Option<Mapping<K>>
|
||||||
where
|
where
|
||||||
K: std::hash::Hash + std::cmp::Eq + std::clone::Clone,
|
K: std::hash::Hash + std::cmp::Eq + std::clone::Clone,
|
||||||
{
|
{
|
||||||
|
|
Loading…
Reference in New Issue