Compare commits
3 Commits
master
...
refactor_c
Author | SHA1 | Date | |
---|---|---|---|
45ae761ed9 | |||
597eaa0873 | |||
5c9f688d9e |
@ -1,5 +1,6 @@
|
||||
use std::rc::Rc;
|
||||
|
||||
use indexmap::IndexMap;
|
||||
use nac3parser::ast::{fold::Fold, ExprKind};
|
||||
|
||||
use super::*;
|
||||
@ -439,9 +440,9 @@ impl TopLevelComposer {
|
||||
}
|
||||
}
|
||||
|
||||
/// Analyze the AST and modify the corresponding `TopLevelDef`
|
||||
pub fn start_analysis(&mut self, inference: bool) -> Result<(), HashSet<String>> {
|
||||
self.analyze_top_level_class_type_var()?;
|
||||
self.analyze_top_level_class_bases()?;
|
||||
self.analyze_top_level_class_definition()?;
|
||||
self.analyze_top_level_class_fields_methods()?;
|
||||
self.analyze_top_level_function()?;
|
||||
if inference {
|
||||
@ -451,442 +452,184 @@ impl TopLevelComposer {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// step 1, analyze the type vars associated with top level class
|
||||
fn analyze_top_level_class_type_var(&mut self) -> Result<(), HashSet<String>> {
|
||||
/// step 1, analyze the top level class definitions
|
||||
///
|
||||
/// Checks for class type variables and ancestors adding them to the `TopLevelDef` list
|
||||
fn analyze_top_level_class_definition(&mut self) -> Result<(), HashSet<String>> {
|
||||
let def_list = &self.definition_ast_list;
|
||||
let temp_def_list = self.extract_def_list();
|
||||
let unifier = self.unifier.borrow_mut();
|
||||
let primitives_store = &self.primitives_ty;
|
||||
let mut errors = HashSet::new();
|
||||
|
||||
let mut analyze = |class_def: &Arc<RwLock<TopLevelDef>>, class_ast: &Option<Stmt>| {
|
||||
// only deal with class def here
|
||||
let mut class_def = class_def.write();
|
||||
let (class_bases_ast, class_def_type_vars, class_resolver) = {
|
||||
if let TopLevelDef::Class { type_vars, resolver, .. } = &mut *class_def {
|
||||
let Some(ast::Located { node: ast::StmtKind::ClassDef { bases, .. }, .. }) =
|
||||
class_ast
|
||||
// Initially only copy the definitions of buitin classes and functions
|
||||
// class definitions are added in the same order as they appear in the program
|
||||
let mut temp_def_list: Vec<Arc<RwLock<TopLevelDef>>> =
|
||||
def_list.iter().take(self.builtin_num).map(|f| f.0.clone()).collect_vec();
|
||||
|
||||
// Check for class generic variables and ancestors
|
||||
for (class_def, class_ast) in def_list.iter().skip(self.builtin_num) {
|
||||
if class_ast.is_some() && matches!(&*class_def.read(), TopLevelDef::Class { .. }) {
|
||||
// Add class type variables and direct parents to the `TopLevelDef`
|
||||
if let Err(e) = Self::analyze_class_bases(
|
||||
class_def,
|
||||
class_ast,
|
||||
&temp_def_list,
|
||||
unifier,
|
||||
primitives_store,
|
||||
) {
|
||||
errors.extend(e);
|
||||
}
|
||||
|
||||
// Add class ancestors
|
||||
Self::analyze_class_ancestors(class_def, &temp_def_list);
|
||||
|
||||
// special case classes that inherit from Exception
|
||||
let TopLevelDef::Class { ancestors: class_ancestors, .. } = &*class_def.read()
|
||||
else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
if class_ancestors
|
||||
.iter()
|
||||
.any(|ann| matches!(ann, TypeAnnotation::CustomClass { id, .. } if id.0 == 7))
|
||||
{
|
||||
// if inherited from Exception, the body should be a pass
|
||||
let ast::StmtKind::ClassDef { body, .. } = &class_ast.as_ref().unwrap().node
|
||||
else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
(bases, type_vars, resolver)
|
||||
} else {
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
let class_resolver = class_resolver.as_ref().unwrap();
|
||||
let class_resolver = &**class_resolver;
|
||||
|
||||
let mut is_generic = false;
|
||||
for b in class_bases_ast {
|
||||
match &b.node {
|
||||
// analyze typevars bounded to the class,
|
||||
// only support things like `class A(Generic[T, V])`,
|
||||
// things like `class A(Generic[T, V, ImportedModule.T])` is not supported
|
||||
// i.e. only simple names are allowed in the subscript
|
||||
// should update the TopLevelDef::Class.typevars and the TypeEnum::TObj.params
|
||||
ast::ExprKind::Subscript { value, slice, .. }
|
||||
if {
|
||||
matches!(
|
||||
&value.node,
|
||||
ast::ExprKind::Name { id, .. } if id == &"Generic".into()
|
||||
)
|
||||
} =>
|
||||
{
|
||||
if is_generic {
|
||||
return Err(HashSet::from([format!(
|
||||
"only single Generic[...] is allowed (at {})",
|
||||
b.location
|
||||
)]));
|
||||
for stmt in body {
|
||||
if matches!(
|
||||
stmt.node,
|
||||
ast::StmtKind::FunctionDef { .. } | ast::StmtKind::AnnAssign { .. }
|
||||
) {
|
||||
errors.extend(Err(HashSet::from(["Classes inherited from exception should have no custom fields/methods"])));
|
||||
}
|
||||
is_generic = true;
|
||||
|
||||
let type_var_list: Vec<&ast::Expr<()>>;
|
||||
// if `class A(Generic[T, V, G])`
|
||||
if let ast::ExprKind::Tuple { elts, .. } = &slice.node {
|
||||
type_var_list = elts.iter().collect_vec();
|
||||
// `class A(Generic[T])`
|
||||
} else {
|
||||
type_var_list = vec![&**slice];
|
||||
}
|
||||
|
||||
// parse the type vars
|
||||
let type_vars = type_var_list
|
||||
.into_iter()
|
||||
.map(|e| {
|
||||
class_resolver.parse_type_annotation(
|
||||
&temp_def_list,
|
||||
unifier,
|
||||
primitives_store,
|
||||
e,
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
// check if all are unique type vars
|
||||
let all_unique_type_var = {
|
||||
let mut occurred_type_var_id: HashSet<TypeVarId> = HashSet::new();
|
||||
type_vars.iter().all(|x| {
|
||||
let ty = unifier.get_ty(*x);
|
||||
if let TypeEnum::TVar { id, .. } = ty.as_ref() {
|
||||
occurred_type_var_id.insert(*id)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
};
|
||||
if !all_unique_type_var {
|
||||
return Err(HashSet::from([format!(
|
||||
"duplicate type variable occurs (at {})",
|
||||
slice.location
|
||||
)]));
|
||||
}
|
||||
|
||||
// add to TopLevelDef
|
||||
class_def_type_vars.extend(type_vars);
|
||||
}
|
||||
|
||||
// if others, do nothing in this function
|
||||
_ => continue,
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
};
|
||||
|
||||
let mut errors = HashSet::new();
|
||||
for (class_def, class_ast) in def_list.iter().skip(self.builtin_num) {
|
||||
if class_ast.is_none() {
|
||||
continue;
|
||||
}
|
||||
if let Err(e) = analyze(class_def, class_ast) {
|
||||
errors.extend(e);
|
||||
}
|
||||
temp_def_list.push(class_def.clone());
|
||||
}
|
||||
|
||||
// deal with ancestors of Exception object
|
||||
let TopLevelDef::Class { name, ancestors, object_id, .. } = &mut *def_list[7].0.write()
|
||||
else {
|
||||
unreachable!()
|
||||
};
|
||||
assert_eq!(*name, "Exception".into());
|
||||
ancestors.push(make_self_type_annotation(&[], *object_id));
|
||||
|
||||
if !errors.is_empty() {
|
||||
return Err(errors);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// step 2, base classes.
|
||||
/// now that the type vars of all classes are done, handle base classes and
|
||||
/// put Self class into the ancestors list. We only allow single inheritance
|
||||
fn analyze_top_level_class_bases(&mut self) -> Result<(), HashSet<String>> {
|
||||
/// step 2, class fields and methods
|
||||
fn analyze_top_level_class_fields_methods(&mut self) -> Result<(), HashSet<String>> {
|
||||
// Allow resolving definition IDs in error messages
|
||||
if self.unifier.top_level.is_none() {
|
||||
let ctx = Arc::new(self.make_top_level_context());
|
||||
self.unifier.top_level = Some(ctx);
|
||||
}
|
||||
|
||||
let def_list = &self.definition_ast_list;
|
||||
let temp_def_list = self.extract_def_list();
|
||||
let unifier = self.unifier.borrow_mut();
|
||||
let primitive_types = self.primitives_ty;
|
||||
|
||||
let mut get_direct_parents =
|
||||
|class_def: &Arc<RwLock<TopLevelDef>>, class_ast: &Option<Stmt>| {
|
||||
let mut class_def = class_def.write();
|
||||
let (class_def_id, class_bases, class_ancestors, class_resolver, class_type_vars) = {
|
||||
if let TopLevelDef::Class {
|
||||
ancestors, resolver, object_id, type_vars, ..
|
||||
} = &mut *class_def
|
||||
{
|
||||
let Some(ast::Located {
|
||||
node: ast::StmtKind::ClassDef { bases, .. }, ..
|
||||
}) = class_ast
|
||||
else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
(object_id, bases, ancestors, resolver, type_vars)
|
||||
} else {
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
let class_resolver = class_resolver.as_ref().unwrap();
|
||||
let class_resolver = &**class_resolver;
|
||||
|
||||
let mut has_base = false;
|
||||
for b in class_bases {
|
||||
// type vars have already been handled, so skip on `Generic[...]`
|
||||
if matches!(
|
||||
&b.node,
|
||||
ast::ExprKind::Subscript { value, .. }
|
||||
if matches!(
|
||||
&value.node,
|
||||
ast::ExprKind::Name { id, .. } if id == &"Generic".into()
|
||||
)
|
||||
) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if has_base {
|
||||
return Err(HashSet::from([format!(
|
||||
"a class definition can only have at most one base class \
|
||||
declaration and one generic declaration (at {})",
|
||||
b.location
|
||||
)]));
|
||||
}
|
||||
has_base = true;
|
||||
|
||||
// the function parse_ast_to make sure that no type var occurred in
|
||||
// bast_ty if it is a CustomClassKind
|
||||
let base_ty = parse_ast_to_type_annotation_kinds(
|
||||
class_resolver,
|
||||
&temp_def_list,
|
||||
unifier,
|
||||
&primitive_types,
|
||||
b,
|
||||
vec![(*class_def_id, class_type_vars.clone())]
|
||||
.into_iter()
|
||||
.collect::<HashMap<_, _>>(),
|
||||
)?;
|
||||
|
||||
if let TypeAnnotation::CustomClass { .. } = &base_ty {
|
||||
class_ancestors.push(base_ty);
|
||||
} else {
|
||||
return Err(HashSet::from([format!(
|
||||
"class base declaration can only be custom class (at {})",
|
||||
b.location,
|
||||
)]));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
};
|
||||
|
||||
// first, only push direct parent into the list
|
||||
let mut errors = HashSet::new();
|
||||
for (class_def, class_ast) in self.definition_ast_list.iter_mut().skip(self.builtin_num) {
|
||||
if class_ast.is_none() {
|
||||
continue;
|
||||
}
|
||||
if let Err(e) = get_direct_parents(class_def, class_ast) {
|
||||
errors.extend(e);
|
||||
}
|
||||
}
|
||||
if !errors.is_empty() {
|
||||
return Err(errors);
|
||||
}
|
||||
|
||||
// second, get all ancestors
|
||||
let mut ancestors_store: HashMap<DefinitionId, Vec<TypeAnnotation>> = HashMap::default();
|
||||
let mut get_all_ancestors =
|
||||
|class_def: &Arc<RwLock<TopLevelDef>>| -> Result<(), HashSet<String>> {
|
||||
let class_def = class_def.read();
|
||||
let (class_ancestors, class_id) = {
|
||||
if let TopLevelDef::Class { ancestors, object_id, .. } = &*class_def {
|
||||
(ancestors, *object_id)
|
||||
} else {
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
ancestors_store.insert(
|
||||
class_id,
|
||||
// if class has direct parents, get all ancestors of its parents. Else just empty
|
||||
if class_ancestors.is_empty() {
|
||||
vec![]
|
||||
} else {
|
||||
Self::get_all_ancestors_helper(
|
||||
&class_ancestors[0],
|
||||
temp_def_list.as_slice(),
|
||||
)?
|
||||
},
|
||||
);
|
||||
Ok(())
|
||||
};
|
||||
for (class_def, ast) in self.definition_ast_list.iter().skip(self.builtin_num) {
|
||||
if ast.is_none() {
|
||||
continue;
|
||||
}
|
||||
if let Err(e) = get_all_ancestors(class_def) {
|
||||
errors.extend(e);
|
||||
}
|
||||
}
|
||||
if !errors.is_empty() {
|
||||
return Err(errors);
|
||||
}
|
||||
|
||||
// insert the ancestors to the def list
|
||||
for (class_def, class_ast) in self.definition_ast_list.iter_mut().skip(self.builtin_num) {
|
||||
if class_ast.is_none() {
|
||||
continue;
|
||||
}
|
||||
let mut class_def = class_def.write();
|
||||
let (class_ancestors, class_id, class_type_vars) = {
|
||||
if let TopLevelDef::Class { ancestors, object_id, type_vars, .. } = &mut *class_def
|
||||
{
|
||||
(ancestors, *object_id, type_vars)
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let ans = ancestors_store.get_mut(&class_id).unwrap();
|
||||
class_ancestors.append(ans);
|
||||
|
||||
// insert self type annotation to the front of the vector to maintain the order
|
||||
class_ancestors
|
||||
.insert(0, make_self_type_annotation(class_type_vars.as_slice(), class_id));
|
||||
|
||||
// special case classes that inherit from Exception
|
||||
if class_ancestors
|
||||
.iter()
|
||||
.any(|ann| matches!(ann, TypeAnnotation::CustomClass { id, .. } if id.0 == 7))
|
||||
{
|
||||
// if inherited from Exception, the body should be a pass
|
||||
let ast::StmtKind::ClassDef { body, .. } = &class_ast.as_ref().unwrap().node else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
for stmt in body {
|
||||
if matches!(
|
||||
stmt.node,
|
||||
ast::StmtKind::FunctionDef { .. } | ast::StmtKind::AnnAssign { .. }
|
||||
) {
|
||||
return Err(HashSet::from([
|
||||
"Classes inherited from exception should have no custom fields/methods"
|
||||
.into(),
|
||||
]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// deal with ancestor of Exception object
|
||||
let TopLevelDef::Class { name, ancestors, object_id, .. } =
|
||||
&mut *self.definition_ast_list[7].0.write()
|
||||
else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
assert_eq!(*name, "Exception".into());
|
||||
ancestors.push(make_self_type_annotation(&[], *object_id));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// step 3, class fields and methods
|
||||
fn analyze_top_level_class_fields_methods(&mut self) -> Result<(), HashSet<String>> {
|
||||
let temp_def_list = self.extract_def_list();
|
||||
let primitives = &self.primitives_ty;
|
||||
let def_ast_list = &self.definition_ast_list;
|
||||
let unifier = self.unifier.borrow_mut();
|
||||
let primitives_store = &self.primitives_ty;
|
||||
|
||||
let mut errors: HashSet<String> = HashSet::new();
|
||||
let mut type_var_to_concrete_def: HashMap<Type, TypeAnnotation> = HashMap::new();
|
||||
|
||||
let mut errors = HashSet::new();
|
||||
for (class_def, class_ast) in def_ast_list.iter().skip(self.builtin_num) {
|
||||
if class_ast.is_none() {
|
||||
continue;
|
||||
}
|
||||
if matches!(&*class_def.read(), TopLevelDef::Class { .. }) {
|
||||
for (class_def, class_ast) in def_list.iter().skip(self.builtin_num) {
|
||||
if class_ast.is_some() && matches!(&*class_def.read(), TopLevelDef::Class { .. }) {
|
||||
if let Err(e) = Self::analyze_single_class_methods_fields(
|
||||
class_def,
|
||||
&class_ast.as_ref().unwrap().node,
|
||||
&temp_def_list,
|
||||
unifier,
|
||||
primitives,
|
||||
primitives_store,
|
||||
&mut type_var_to_concrete_def,
|
||||
(&self.keyword_list, &self.core_config),
|
||||
) {
|
||||
errors.extend(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
if !errors.is_empty() {
|
||||
return Err(errors);
|
||||
}
|
||||
|
||||
// handle the inherited methods and fields
|
||||
// Note: we cannot defer error handling til the end of the loop, because there is loop
|
||||
// carried dependency, ignoring the error (temporarily) will cause all assumptions to break
|
||||
// and produce weird error messages
|
||||
let mut current_ancestor_depth: usize = 2;
|
||||
loop {
|
||||
let mut finished = true;
|
||||
|
||||
for (class_def, class_ast) in def_ast_list.iter().skip(self.builtin_num) {
|
||||
if class_ast.is_none() {
|
||||
continue;
|
||||
// The errors need to be reported before copying methods from parent to child classes
|
||||
if !errors.is_empty() {
|
||||
return Err(errors);
|
||||
}
|
||||
let mut class_def = class_def.write();
|
||||
if let TopLevelDef::Class { ancestors, .. } = &*class_def {
|
||||
// if the length of the ancestor is equal to the current depth
|
||||
// it means that all the ancestors of the class is handled
|
||||
if ancestors.len() == current_ancestor_depth {
|
||||
finished = false;
|
||||
Self::analyze_single_class_ancestors(
|
||||
|
||||
// The lock on `class_def` must be released once the ancestors are updated
|
||||
{
|
||||
let mut class_def = class_def.write();
|
||||
let TopLevelDef::Class { ancestors, .. } = &*class_def else { unreachable!() };
|
||||
// Methods/fields needs to be processed only if class inherits from another class
|
||||
if ancestors.len() > 1 {
|
||||
if let Err(e) = Self::analyze_single_class_ancestors(
|
||||
&mut class_def,
|
||||
&temp_def_list,
|
||||
unifier,
|
||||
primitives,
|
||||
primitives_store,
|
||||
&mut type_var_to_concrete_def,
|
||||
)?;
|
||||
) {
|
||||
errors.extend(e);
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
let mut subst_list = Some(Vec::new());
|
||||
// unification of previously assigned typevar
|
||||
let mut unification_helper = |ty, def| -> Result<(), HashSet<String>> {
|
||||
let target_ty = get_type_from_type_annotation_kinds(
|
||||
&temp_def_list,
|
||||
unifier,
|
||||
primitives_store,
|
||||
&def,
|
||||
&mut subst_list,
|
||||
)?;
|
||||
unifier
|
||||
.unify(ty, target_ty)
|
||||
.map_err(|e| HashSet::from([e.to_display(unifier).to_string()]))?;
|
||||
Ok(())
|
||||
};
|
||||
for (ty, def) in &type_var_to_concrete_def {
|
||||
if let Err(e) = unification_helper(*ty, def.clone()) {
|
||||
errors.extend(e);
|
||||
}
|
||||
}
|
||||
for ty in subst_list.unwrap() {
|
||||
let TypeEnum::TObj { obj_id, params, fields } = &*unifier.get_ty(ty) else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
let mut new_fields = HashMap::new();
|
||||
let mut need_subst = false;
|
||||
for (name, (ty, mutable)) in fields {
|
||||
let substituted = unifier.subst(*ty, params);
|
||||
need_subst |= substituted.is_some();
|
||||
new_fields.insert(*name, (substituted.unwrap_or(*ty), *mutable));
|
||||
}
|
||||
if need_subst {
|
||||
let new_ty = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: *obj_id,
|
||||
params: params.clone(),
|
||||
fields: new_fields,
|
||||
});
|
||||
if let Err(e) = unifier.unify(ty, new_ty) {
|
||||
errors.insert(e.to_display(unifier).to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if finished {
|
||||
break;
|
||||
}
|
||||
|
||||
current_ancestor_depth += 1;
|
||||
if current_ancestor_depth > def_ast_list.len() + 1 {
|
||||
unreachable!("cannot be longer than the whole top level def list")
|
||||
}
|
||||
}
|
||||
|
||||
let mut subst_list = Some(Vec::new());
|
||||
// unification of previously assigned typevar
|
||||
let mut unification_helper = |ty, def| -> Result<(), HashSet<String>> {
|
||||
let target_ty = get_type_from_type_annotation_kinds(
|
||||
&temp_def_list,
|
||||
unifier,
|
||||
primitives,
|
||||
&def,
|
||||
&mut subst_list,
|
||||
)?;
|
||||
unifier
|
||||
.unify(ty, target_ty)
|
||||
.map_err(|e| HashSet::from([e.to_display(unifier).to_string()]))?;
|
||||
Ok(())
|
||||
};
|
||||
for (ty, def) in type_var_to_concrete_def {
|
||||
if let Err(e) = unification_helper(ty, def) {
|
||||
errors.extend(e);
|
||||
}
|
||||
}
|
||||
for ty in subst_list.unwrap() {
|
||||
let TypeEnum::TObj { obj_id, params, fields } = &*unifier.get_ty(ty) else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
let mut new_fields = HashMap::new();
|
||||
let mut need_subst = false;
|
||||
for (name, (ty, mutable)) in fields {
|
||||
let substituted = unifier.subst(*ty, params);
|
||||
need_subst |= substituted.is_some();
|
||||
new_fields.insert(*name, (substituted.unwrap_or(*ty), *mutable));
|
||||
}
|
||||
if need_subst {
|
||||
let new_ty = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: *obj_id,
|
||||
params: params.clone(),
|
||||
fields: new_fields,
|
||||
});
|
||||
if let Err(e) = unifier.unify(ty, new_ty) {
|
||||
errors.insert(e.to_display(unifier).to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
if !errors.is_empty() {
|
||||
return Err(errors);
|
||||
}
|
||||
|
||||
for (def, _) in def_ast_list.iter().skip(self.builtin_num) {
|
||||
for (def, _) in def_list.iter().skip(self.builtin_num) {
|
||||
match &*def.read() {
|
||||
TopLevelDef::Class { resolver: Some(resolver), .. }
|
||||
| TopLevelDef::Function { resolver: Some(resolver), .. } => {
|
||||
if let Err(e) =
|
||||
resolver.handle_deferred_eval(unifier, &temp_def_list, primitives)
|
||||
resolver.handle_deferred_eval(unifier, &temp_def_list, primitives_store)
|
||||
{
|
||||
errors.insert(e);
|
||||
}
|
||||
@ -895,10 +638,13 @@ impl TopLevelComposer {
|
||||
}
|
||||
}
|
||||
|
||||
if !errors.is_empty() {
|
||||
return Err(errors);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// step 4, after class methods are done, top level functions have nothing unknown
|
||||
/// step 3, after class methods are done, top level functions have nothing unknown
|
||||
fn analyze_top_level_function(&mut self) -> Result<(), HashSet<String>> {
|
||||
let def_list = &self.definition_ast_list;
|
||||
let keyword_list = &self.keyword_list;
|
||||
@ -1253,126 +999,83 @@ impl TopLevelComposer {
|
||||
let mut method_var_map = VarMap::new();
|
||||
|
||||
let arg_types: Vec<FuncArg> = {
|
||||
// check method parameters cannot have same name
|
||||
// Function arguments must have:
|
||||
// 1) `self` as first argument (we currently do not support staticmethods)
|
||||
// 2) unique names
|
||||
// 3) names different than keywords
|
||||
match args.args.first() {
|
||||
Some(id) if id.node.arg == "self".into() => {},
|
||||
_ => return Err(HashSet::from([format!(
|
||||
"{name} method must have a `self` parameter (at {})", b.location
|
||||
)])),
|
||||
}
|
||||
let mut defined_parameter_name: HashSet<_> = HashSet::new();
|
||||
let zelf: StrRef = "self".into();
|
||||
for x in &args.args {
|
||||
if !defined_parameter_name.insert(x.node.arg)
|
||||
|| (keyword_list.contains(&x.node.arg) && x.node.arg != zelf)
|
||||
{
|
||||
return Err(HashSet::from([
|
||||
format!("top level function must have unique parameter names \
|
||||
and names should not be the same as the keywords (at {})",
|
||||
x.location),
|
||||
]))
|
||||
for arg in args.args.iter().skip(1) {
|
||||
if !defined_parameter_name.insert(arg.node.arg) {
|
||||
return Err(HashSet::from([format!("class method must have a unique parameter names (at {})", b.location)]));
|
||||
}
|
||||
if keyword_list.contains(&arg.node.arg) {
|
||||
return Err(HashSet::from([format!("parameter names should not be the same as the keywords (at {})", b.location)]));
|
||||
}
|
||||
}
|
||||
|
||||
if name == &"__init__".into() && !defined_parameter_name.contains(&zelf) {
|
||||
return Err(HashSet::from([
|
||||
format!("__init__ method must have a `self` parameter (at {})", b.location),
|
||||
]))
|
||||
// `self` must not be provided type annotation or default value
|
||||
if args.args.len() == args.defaults.len() {
|
||||
return Err(HashSet::from([format!("`self` cannot have a default value (at {})", b.location)]));
|
||||
}
|
||||
if !defined_parameter_name.contains(&zelf) {
|
||||
return Err(HashSet::from([
|
||||
format!("class method must have a `self` parameter (at {})", b.location),
|
||||
]))
|
||||
if args.args[0].node.annotation.is_some() {
|
||||
return Err(HashSet::from([format!("`self` cannot have a type annotation (at {})", b.location)]));
|
||||
}
|
||||
|
||||
let mut result = Vec::new();
|
||||
|
||||
let arg_with_default: Vec<(
|
||||
&ast::Located<ast::ArgData<()>>,
|
||||
Option<&ast::Expr>,
|
||||
)> = args
|
||||
.args
|
||||
.iter()
|
||||
.rev()
|
||||
.zip(
|
||||
args.defaults
|
||||
.iter()
|
||||
.rev()
|
||||
.map(|x| -> Option<&ast::Expr> { Some(x) })
|
||||
.chain(std::iter::repeat(None)),
|
||||
)
|
||||
.collect_vec();
|
||||
|
||||
for (x, default) in arg_with_default.into_iter().rev() {
|
||||
let name = x.node.arg;
|
||||
if name != zelf {
|
||||
let type_ann = {
|
||||
let annotation_expr = x
|
||||
.node
|
||||
.annotation
|
||||
.as_ref()
|
||||
.ok_or_else(|| HashSet::from([
|
||||
format!(
|
||||
"type annotation needed for `{}` at {}",
|
||||
x.node.arg, x.location
|
||||
),
|
||||
]))?
|
||||
.as_ref();
|
||||
parse_ast_to_type_annotation_kinds(
|
||||
class_resolver,
|
||||
temp_def_list,
|
||||
unifier,
|
||||
primitives,
|
||||
annotation_expr,
|
||||
vec![(class_id, class_type_vars_def.clone())]
|
||||
.into_iter()
|
||||
.collect::<HashMap<_, _>>(),
|
||||
)?
|
||||
let no_defaults = args.args.len() - args.defaults.len() - 1;
|
||||
for (idx, x) in itertools::enumerate(args.args.iter().skip(1)) {
|
||||
let type_ann = {
|
||||
let Some(annotation_expr) = x.node.annotation.as_ref() else {return Err(HashSet::from([format!("type annotation needed for `{}` (at {})", x.node.arg, x.location)]));};
|
||||
parse_ast_to_type_annotation_kinds(
|
||||
class_resolver,
|
||||
temp_def_list,
|
||||
unifier,
|
||||
primitives,
|
||||
annotation_expr,
|
||||
vec![(class_id, class_type_vars_def.clone())]
|
||||
.into_iter()
|
||||
.collect::<HashMap<_, _>>(),
|
||||
)?
|
||||
};
|
||||
// find type vars within this method parameter type annotation
|
||||
let type_vars_within = get_type_var_contained_in_type_annotation(&type_ann);
|
||||
// handle the class type var and the method type var
|
||||
for type_var_within in type_vars_within {
|
||||
let TypeAnnotation::TypeVar(ty) = type_var_within else {
|
||||
unreachable!("must be type var annotation")
|
||||
};
|
||||
// find type vars within this method parameter type annotation
|
||||
let type_vars_within =
|
||||
get_type_var_contained_in_type_annotation(&type_ann);
|
||||
// handle the class type var and the method type var
|
||||
for type_var_within in type_vars_within {
|
||||
let TypeAnnotation::TypeVar(ty) = type_var_within else {
|
||||
unreachable!("must be type var annotation")
|
||||
};
|
||||
|
||||
let id = Self::get_var_id(ty, unifier)?;
|
||||
if let Some(prev_ty) = method_var_map.insert(id, ty) {
|
||||
// if already in the list, make sure they are the same?
|
||||
assert_eq!(prev_ty, ty);
|
||||
}
|
||||
let id = Self::get_var_id(ty, unifier)?;
|
||||
if let Some(prev_ty) = method_var_map.insert(id, ty) {
|
||||
// if already in the list, make sure they are the same?
|
||||
assert_eq!(prev_ty, ty);
|
||||
}
|
||||
// finish handling type vars
|
||||
let dummy_func_arg = FuncArg {
|
||||
name,
|
||||
ty: unifier.get_dummy_var().ty,
|
||||
default_value: match default {
|
||||
None => None,
|
||||
Some(default) => {
|
||||
if name == "self".into() {
|
||||
return Err(HashSet::from([
|
||||
format!("`self` parameter cannot take default value (at {})", x.location),
|
||||
]));
|
||||
}
|
||||
Some({
|
||||
let v = Self::parse_parameter_default_value(
|
||||
default,
|
||||
class_resolver,
|
||||
)?;
|
||||
Self::check_default_param_type(
|
||||
&v, &type_ann, primitives, unifier,
|
||||
)
|
||||
.map_err(|err| HashSet::from([
|
||||
format!("{} (at {})", err, x.location),
|
||||
]))?;
|
||||
v
|
||||
})
|
||||
}
|
||||
},
|
||||
is_vararg: false,
|
||||
};
|
||||
// push the dummy type and the type annotation
|
||||
// into the list for later unification
|
||||
type_var_to_concrete_def
|
||||
.insert(dummy_func_arg.ty, type_ann.clone());
|
||||
result.push(dummy_func_arg);
|
||||
}
|
||||
// finish handling type vars
|
||||
let dummy_func_arg = FuncArg {
|
||||
name: x.node.arg,
|
||||
ty: unifier.get_dummy_var().ty,
|
||||
default_value: if idx < no_defaults { None } else {
|
||||
let default_idx = idx - no_defaults;
|
||||
|
||||
Some({
|
||||
let v = Self::parse_parameter_default_value(&args.defaults[default_idx], class_resolver)?;
|
||||
Self::check_default_param_type(&v, &type_ann, primitives, unifier).map_err(|err| HashSet::from([format!("{} (at {})", err, x.location)]))?;
|
||||
v
|
||||
})
|
||||
},
|
||||
is_vararg: false,
|
||||
};
|
||||
// push the dummy type and the type annotation
|
||||
// into the list for later unification
|
||||
type_var_to_concrete_def
|
||||
.insert(dummy_func_arg.ty, type_ann.clone());
|
||||
result.push(dummy_func_arg);
|
||||
}
|
||||
result
|
||||
};
|
||||
@ -1494,12 +1197,12 @@ impl TopLevelComposer {
|
||||
match v {
|
||||
ast::Constant::Bool(_) | ast::Constant::Str(_) | ast::Constant::Int(_) | ast::Constant::Float(_) => {}
|
||||
_ => {
|
||||
return Err(HashSet::from([
|
||||
format!(
|
||||
"unsupported statement in class definition body (at {})",
|
||||
b.location
|
||||
),
|
||||
]))
|
||||
return Err(HashSet::from([
|
||||
format!(
|
||||
"unsupported statement in class definition body (at {})",
|
||||
b.location
|
||||
),
|
||||
]))
|
||||
}
|
||||
}
|
||||
class_attributes_def.push((*attr, dummy_field_type, v.clone()));
|
||||
@ -1535,7 +1238,7 @@ impl TopLevelComposer {
|
||||
unreachable!("must be type var annotation")
|
||||
};
|
||||
|
||||
if !class_type_vars_def.contains(&t) {
|
||||
if !class_type_vars_def.contains(&t){
|
||||
return Err(HashSet::from([
|
||||
format!(
|
||||
"class fields can only use type \
|
||||
@ -1569,7 +1272,7 @@ impl TopLevelComposer {
|
||||
_ => {
|
||||
return Err(HashSet::from([
|
||||
format!(
|
||||
"unsupported statement in class definition body (at {})",
|
||||
"unsupported statement type in class definition body (at {})",
|
||||
b.location
|
||||
),
|
||||
]))
|
||||
@ -1615,7 +1318,6 @@ impl TopLevelComposer {
|
||||
let TypeAnnotation::CustomClass { id, params: _ } = base else {
|
||||
unreachable!("must be class type annotation")
|
||||
};
|
||||
|
||||
let base = temp_def_list.get(id.0).unwrap();
|
||||
let base = base.read();
|
||||
let TopLevelDef::Class { methods, fields, attributes, .. } = &*base else {
|
||||
@ -1624,93 +1326,68 @@ impl TopLevelComposer {
|
||||
|
||||
// handle methods override
|
||||
// since we need to maintain the order, create a new list
|
||||
let mut new_child_methods: Vec<(StrRef, Type, DefinitionId)> = Vec::new();
|
||||
let mut is_override: HashSet<StrRef> = HashSet::new();
|
||||
for (anc_method_name, anc_method_ty, anc_method_def_id) in methods {
|
||||
// find if there is a method with same name in the child class
|
||||
let mut to_be_added = (*anc_method_name, *anc_method_ty, *anc_method_def_id);
|
||||
for (class_method_name, class_method_ty, class_method_defid) in &*class_methods_def {
|
||||
if class_method_name == anc_method_name {
|
||||
// ignore and handle self
|
||||
// if is __init__ method, no need to check return type
|
||||
let ok = class_method_name == &"__init__".into()
|
||||
|| Self::check_overload_function_type(
|
||||
*class_method_ty,
|
||||
*anc_method_ty,
|
||||
unifier,
|
||||
type_var_to_concrete_def,
|
||||
);
|
||||
if !ok {
|
||||
return Err(HashSet::from([format!(
|
||||
"method {class_method_name} has same name as ancestors' method, but incompatible type"),
|
||||
]));
|
||||
}
|
||||
// mark it as added
|
||||
is_override.insert(*class_method_name);
|
||||
to_be_added = (*class_method_name, *class_method_ty, *class_method_defid);
|
||||
break;
|
||||
let mut new_child_methods: IndexMap<StrRef, (Type, DefinitionId)> =
|
||||
methods.iter().map(|m| (m.0, (m.1, m.2))).collect();
|
||||
|
||||
// let mut new_child_methods: Vec<(StrRef, Type, DefinitionId)> = methods.clone();
|
||||
for (class_method_name, class_method_ty, class_method_defid) in &*class_methods_def {
|
||||
if let Some((ty, _)) = new_child_methods
|
||||
.insert(*class_method_name, (*class_method_ty, *class_method_defid))
|
||||
{
|
||||
let ok = class_method_name == &"__init__".into()
|
||||
|| Self::check_overload_function_type(
|
||||
*class_method_ty,
|
||||
ty,
|
||||
unifier,
|
||||
type_var_to_concrete_def,
|
||||
);
|
||||
if !ok {
|
||||
return Err(HashSet::from([format!(
|
||||
"method {class_method_name} has same name as ancestors' method, but incompatible type"),
|
||||
]));
|
||||
}
|
||||
}
|
||||
new_child_methods.push(to_be_added);
|
||||
}
|
||||
// add those that are not overriding method to the new_child_methods
|
||||
for (class_method_name, class_method_ty, class_method_defid) in &*class_methods_def {
|
||||
if !is_override.contains(class_method_name) {
|
||||
new_child_methods.push((*class_method_name, *class_method_ty, *class_method_defid));
|
||||
}
|
||||
}
|
||||
// use the new_child_methods to replace all the elements in `class_methods_def`
|
||||
class_methods_def.clear();
|
||||
class_methods_def.extend(new_child_methods);
|
||||
class_methods_def
|
||||
.extend(new_child_methods.iter().map(|f| (*f.0, f.1 .0, f.1 .1)).collect_vec());
|
||||
|
||||
// handle class fields
|
||||
let mut new_child_fields: Vec<(StrRef, Type, bool)> = Vec::new();
|
||||
// let mut is_override: HashSet<_> = HashSet::new();
|
||||
for (anc_field_name, anc_field_ty, mutable) in fields {
|
||||
let to_be_added = (*anc_field_name, *anc_field_ty, *mutable);
|
||||
// find if there is a fields with the same name in the child class
|
||||
for (class_field_name, ..) in &*class_fields_def {
|
||||
if class_field_name == anc_field_name
|
||||
|| attributes.iter().any(|f| f.0 == *class_field_name)
|
||||
{
|
||||
return Err(HashSet::from([format!(
|
||||
"field `{class_field_name}` has already declared in the ancestor classes"
|
||||
)]));
|
||||
}
|
||||
let mut new_child_fields: IndexMap<StrRef, (Type, bool)> =
|
||||
fields.iter().map(|f| (f.0, (f.1, f.2))).collect();
|
||||
let mut new_child_attributes: IndexMap<StrRef, (Type, ast::Constant)> =
|
||||
attributes.iter().map(|f| (f.0, (f.1, f.2.clone()))).collect();
|
||||
// Overriding class fields and attributes is currently not supported
|
||||
for (name, ty, mutable) in &*class_fields_def {
|
||||
if new_child_fields.insert(*name, (*ty, *mutable)).is_some()
|
||||
|| new_child_attributes.contains_key(name)
|
||||
{
|
||||
return Err(HashSet::from([format!(
|
||||
"field `{name}` has already declared in the ancestor classes"
|
||||
)]));
|
||||
}
|
||||
}
|
||||
for (name, ty, val) in &*class_attribute_def {
|
||||
if new_child_attributes.insert(*name, (*ty, val.clone())).is_some()
|
||||
|| new_child_fields.contains_key(name)
|
||||
{
|
||||
return Err(HashSet::from([format!(
|
||||
"attribute `{name}` has already declared in the ancestor classes"
|
||||
)]));
|
||||
}
|
||||
new_child_fields.push(to_be_added);
|
||||
}
|
||||
|
||||
// handle class attributes
|
||||
let mut new_child_attributes: Vec<(StrRef, Type, ast::Constant)> = Vec::new();
|
||||
for (anc_attr_name, anc_attr_ty, attr_value) in attributes {
|
||||
let to_be_added = (*anc_attr_name, *anc_attr_ty, attr_value.clone());
|
||||
// find if there is a attribute with the same name in the child class
|
||||
for (class_attr_name, ..) in &*class_attribute_def {
|
||||
if class_attr_name == anc_attr_name
|
||||
|| fields.iter().any(|f| f.0 == *class_attr_name)
|
||||
{
|
||||
return Err(HashSet::from([format!(
|
||||
"attribute `{class_attr_name}` has already declared in the ancestor classes"
|
||||
)]));
|
||||
}
|
||||
}
|
||||
new_child_attributes.push(to_be_added);
|
||||
}
|
||||
|
||||
for (class_field_name, class_field_ty, mutable) in &*class_fields_def {
|
||||
if !is_override.contains(class_field_name) {
|
||||
new_child_fields.push((*class_field_name, *class_field_ty, *mutable));
|
||||
}
|
||||
}
|
||||
class_fields_def.clear();
|
||||
class_fields_def.extend(new_child_fields);
|
||||
class_fields_def
|
||||
.extend(new_child_fields.iter().map(|f| (*f.0, f.1 .0, f.1 .1)).collect_vec());
|
||||
class_attribute_def.clear();
|
||||
class_attribute_def.extend(new_child_attributes);
|
||||
class_attribute_def.extend(
|
||||
new_child_attributes.iter().map(|f| (*f.0, f.1 .0, f.1 .1.clone())).collect_vec(),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// step 5, analyze and call type inferencer to fill the `instance_to_stmt` of
|
||||
/// step 4, analyze and call type inferencer to fill the `instance_to_stmt` of
|
||||
/// [`TopLevelDef::Function`]
|
||||
fn analyze_function_instance(&mut self) -> Result<(), HashSet<String>> {
|
||||
// first get the class constructor type correct for the following type check in function body
|
||||
@ -2229,7 +1906,7 @@ impl TopLevelComposer {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Step 6. Analyze and populate the types of global variables.
|
||||
/// Step 5. Analyze and populate the types of global variables.
|
||||
fn analyze_top_level_variables(&mut self) -> Result<(), HashSet<String>> {
|
||||
let def_list = &self.definition_ast_list;
|
||||
let temp_def_list = self.extract_def_list();
|
||||
|
@ -626,64 +626,6 @@ impl TopLevelComposer {
|
||||
Err(HashSet::from([format!("no method {method_name} in the current class")]))
|
||||
}
|
||||
|
||||
/// get all base class def id of a class, excluding itself. \
|
||||
/// this function should called only after the direct parent is set
|
||||
/// and before all the ancestors are set
|
||||
/// and when we allow single inheritance \
|
||||
/// the order of the returned list is from the child to the deepest ancestor
|
||||
pub fn get_all_ancestors_helper(
|
||||
child: &TypeAnnotation,
|
||||
temp_def_list: &[Arc<RwLock<TopLevelDef>>],
|
||||
) -> Result<Vec<TypeAnnotation>, HashSet<String>> {
|
||||
let mut result: Vec<TypeAnnotation> = Vec::new();
|
||||
let mut parent = Self::get_parent(child, temp_def_list);
|
||||
while let Some(p) = parent {
|
||||
parent = Self::get_parent(&p, temp_def_list);
|
||||
let p_id = if let TypeAnnotation::CustomClass { id, .. } = &p {
|
||||
*id
|
||||
} else {
|
||||
unreachable!("must be class kind annotation")
|
||||
};
|
||||
// check cycle
|
||||
let no_cycle = result.iter().all(|x| {
|
||||
let TypeAnnotation::CustomClass { id, .. } = x else {
|
||||
unreachable!("must be class kind annotation")
|
||||
};
|
||||
|
||||
id.0 != p_id.0
|
||||
});
|
||||
if no_cycle {
|
||||
result.push(p);
|
||||
} else {
|
||||
return Err(HashSet::from(["cyclic inheritance detected".into()]));
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// should only be called when finding all ancestors, so panic when wrong
|
||||
fn get_parent(
|
||||
child: &TypeAnnotation,
|
||||
temp_def_list: &[Arc<RwLock<TopLevelDef>>],
|
||||
) -> Option<TypeAnnotation> {
|
||||
let child_id = if let TypeAnnotation::CustomClass { id, .. } = child {
|
||||
*id
|
||||
} else {
|
||||
unreachable!("should be class type annotation")
|
||||
};
|
||||
let child_def = temp_def_list.get(child_id.0).unwrap();
|
||||
let child_def = child_def.read();
|
||||
let TopLevelDef::Class { ancestors, .. } = &*child_def else {
|
||||
unreachable!("child must be top level class def")
|
||||
};
|
||||
|
||||
if ancestors.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(ancestors[0].clone())
|
||||
}
|
||||
}
|
||||
|
||||
/// get the `var_id` of a given `TVar` type
|
||||
pub fn get_var_id(var_ty: Type, unifier: &mut Unifier) -> Result<TypeVarId, HashSet<String>> {
|
||||
if let TypeEnum::TVar { id, .. } = unifier.get_ty(var_ty).as_ref() {
|
||||
@ -993,6 +935,139 @@ impl TopLevelComposer {
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// Parses the class type variables and direct parents
|
||||
/// we only allow single inheritance
|
||||
pub fn analyze_class_bases(
|
||||
class_def: &Arc<RwLock<TopLevelDef>>,
|
||||
class_ast: &Option<Stmt>,
|
||||
temp_def_list: &[Arc<RwLock<TopLevelDef>>],
|
||||
unifier: &mut Unifier,
|
||||
primitives_store: &PrimitiveStore,
|
||||
) -> Result<(), HashSet<String>> {
|
||||
let mut class_def = class_def.write();
|
||||
let (class_def_id, class_ancestors, class_bases_ast, class_type_vars, class_resolver) = {
|
||||
let TopLevelDef::Class { object_id, ancestors, type_vars, resolver, .. } =
|
||||
&mut *class_def
|
||||
else {
|
||||
unreachable!()
|
||||
};
|
||||
let Some(ast::Located { node: ast::StmtKind::ClassDef { bases, .. }, .. }) = class_ast
|
||||
else {
|
||||
unreachable!()
|
||||
};
|
||||
(object_id, ancestors, bases, type_vars, resolver.as_ref().unwrap().as_ref())
|
||||
};
|
||||
|
||||
let mut is_generic = false;
|
||||
let mut has_base = false;
|
||||
// Check class bases for typevars
|
||||
for b in class_bases_ast {
|
||||
match &b.node {
|
||||
// analyze typevars bounded to the class,
|
||||
// only support things like `class A(Generic[T, V])`,
|
||||
// things like `class A(Generic[T, V, ImportedModule.T])` is not supported
|
||||
// i.e. only simple names are allowed in the subscript
|
||||
// should update the TopLevelDef::Class.typevars and the TypeEnum::TObj.params
|
||||
ast::ExprKind::Subscript { value, slice, .. } if matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"Generic".into()) =>
|
||||
{
|
||||
if is_generic {
|
||||
return Err(HashSet::from([format!(
|
||||
"only single Generic[...] is allowed (at {})",
|
||||
b.location
|
||||
)]));
|
||||
}
|
||||
is_generic = true;
|
||||
|
||||
let type_var_list: Vec<&ast::Expr<()>>;
|
||||
// if `class A(Generic[T, V, G])`
|
||||
if let ast::ExprKind::Tuple { elts, .. } = &slice.node {
|
||||
type_var_list = elts.iter().collect_vec();
|
||||
// `class A(Generic[T])`
|
||||
} else {
|
||||
type_var_list = vec![&**slice];
|
||||
}
|
||||
|
||||
let type_vars = type_var_list
|
||||
.into_iter()
|
||||
.map(|e| {
|
||||
class_resolver.parse_type_annotation(
|
||||
temp_def_list,
|
||||
unifier,
|
||||
primitives_store,
|
||||
e,
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
class_type_vars.extend(type_vars);
|
||||
}
|
||||
ast::ExprKind::Name { .. } | ast::ExprKind::Subscript { .. } => {
|
||||
if has_base {
|
||||
return Err(HashSet::from([format!("a class definition can only have at most one base class declaration and one generic declaration (at {})", b.location )]));
|
||||
}
|
||||
has_base = true;
|
||||
// the function parse_ast_to make sure that no type var occurred in
|
||||
// bast_ty if it is a CustomClassKind
|
||||
let base_ty = parse_ast_to_type_annotation_kinds(
|
||||
class_resolver,
|
||||
temp_def_list,
|
||||
unifier,
|
||||
primitives_store,
|
||||
b,
|
||||
vec![(*class_def_id, class_type_vars.clone())]
|
||||
.into_iter()
|
||||
.collect::<HashMap<_, _>>(),
|
||||
)?;
|
||||
if let TypeAnnotation::CustomClass { .. } = &base_ty {
|
||||
class_ancestors.push(base_ty);
|
||||
} else {
|
||||
return Err(HashSet::from([format!(
|
||||
"class base declaration can only be custom class (at {})",
|
||||
b.location
|
||||
)]));
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
return Err(HashSet::from([format!(
|
||||
"unsupported statement in class defintion (at {})",
|
||||
b.location
|
||||
)]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// gets all ancestors of a class
|
||||
pub fn analyze_class_ancestors(
|
||||
class_def: &Arc<RwLock<TopLevelDef>>,
|
||||
temp_def_list: &[Arc<RwLock<TopLevelDef>>],
|
||||
) {
|
||||
// Check if class has a direct parent
|
||||
let mut class_def = class_def.write();
|
||||
let TopLevelDef::Class { ancestors, type_vars, object_id, .. } = &mut *class_def else {
|
||||
unreachable!()
|
||||
};
|
||||
let mut anc_set = HashMap::new();
|
||||
|
||||
if let Some(ancestor) = ancestors.first() {
|
||||
let TypeAnnotation::CustomClass { id, .. } = ancestor else { unreachable!() };
|
||||
let TopLevelDef::Class { ancestors: parent_ancestors, .. } =
|
||||
&*temp_def_list[id.0].read()
|
||||
else {
|
||||
unreachable!()
|
||||
};
|
||||
for anc in parent_ancestors.iter().skip(1) {
|
||||
let TypeAnnotation::CustomClass { id, .. } = anc else { unreachable!() };
|
||||
anc_set.insert(id, anc.clone());
|
||||
}
|
||||
ancestors.extend(anc_set.into_values());
|
||||
}
|
||||
// push `self` as first ancestor of class
|
||||
ancestors.insert(0, make_self_type_annotation(type_vars.as_slice(), *object_id));
|
||||
}
|
||||
}
|
||||
|
||||
pub fn parse_parameter_default_value(
|
||||
|
@ -3,10 +3,10 @@ source: nac3core/src/toplevel/test.rs
|
||||
expression: res_vec
|
||||
---
|
||||
[
|
||||
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
||||
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(241)]\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
||||
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(246)]\n}\n",
|
||||
]
|
||||
|
@ -7,11 +7,11 @@ expression: res_vec
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(249)]\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(257)]\n}\n",
|
||||
]
|
||||
|
@ -230,11 +230,6 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
|
||||
def foo(self, a: T, b: V):
|
||||
pass
|
||||
"},
|
||||
indoc! {"
|
||||
class B(C):
|
||||
def __init__(self):
|
||||
pass
|
||||
"},
|
||||
indoc! {"
|
||||
class C(A):
|
||||
def __init__(self):
|
||||
@ -243,6 +238,11 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
|
||||
a = 1
|
||||
pass
|
||||
"},
|
||||
indoc! {"
|
||||
class B(C):
|
||||
def __init__(self):
|
||||
pass
|
||||
"},
|
||||
indoc! {"
|
||||
def foo(a: A):
|
||||
pass
|
||||
@ -257,6 +257,14 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
|
||||
)]
|
||||
#[test_case(
|
||||
&[
|
||||
indoc! {"
|
||||
class B:
|
||||
aa: bool
|
||||
def __init__(self):
|
||||
self.aa = False
|
||||
def foo(self, b: T):
|
||||
pass
|
||||
"},
|
||||
indoc! {"
|
||||
class Generic_A(Generic[V], B):
|
||||
a: int64
|
||||
@ -264,14 +272,6 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
|
||||
self.a = 123123123123
|
||||
def fun(self, a: int32) -> V:
|
||||
pass
|
||||
"},
|
||||
indoc! {"
|
||||
class B:
|
||||
aa: bool
|
||||
def __init__(self):
|
||||
self.aa = False
|
||||
def foo(self, b: T):
|
||||
pass
|
||||
"}
|
||||
],
|
||||
&[];
|
||||
@ -391,18 +391,18 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
|
||||
pass
|
||||
"}
|
||||
],
|
||||
&["cyclic inheritance detected"];
|
||||
&["NameError: name 'B' is not defined (at unknown:1:9)"];
|
||||
"cyclic1"
|
||||
)]
|
||||
#[test_case(
|
||||
&[
|
||||
indoc! {"
|
||||
class A(B[bool, int64]):
|
||||
def __init__(self):
|
||||
pass
|
||||
class B(Generic[V, T], C[int32]):
|
||||
def __init__(self):
|
||||
pass
|
||||
"},
|
||||
indoc! {"
|
||||
class B(Generic[V, T], C[int32]):
|
||||
class A(B[bool, int64]):
|
||||
def __init__(self):
|
||||
pass
|
||||
"},
|
||||
@ -412,7 +412,7 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
|
||||
pass
|
||||
"},
|
||||
],
|
||||
&["cyclic inheritance detected"];
|
||||
&["NameError: name 'C' is not defined (at unknown:1:25)"];
|
||||
"cyclic2"
|
||||
)]
|
||||
#[test_case(
|
||||
@ -436,11 +436,6 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
|
||||
)]
|
||||
#[test_case(
|
||||
&[
|
||||
indoc! {"
|
||||
class A(B, Generic[T], C):
|
||||
def __init__(self):
|
||||
pass
|
||||
"},
|
||||
indoc! {"
|
||||
class B:
|
||||
def __init__(self):
|
||||
@ -450,6 +445,11 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
|
||||
class C:
|
||||
def __init__(self):
|
||||
pass
|
||||
"},
|
||||
indoc! {"
|
||||
class A(B, Generic[T], C):
|
||||
def __init__(self):
|
||||
pass
|
||||
"}
|
||||
|
||||
],
|
||||
|
@ -101,7 +101,13 @@ pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
|
||||
Ok(TypeAnnotation::CustomClass { id: PrimDef::Exception.id(), params: Vec::default() })
|
||||
} else if let Ok(obj_id) = resolver.get_identifier_def(*id) {
|
||||
let type_vars = {
|
||||
let def_read = top_level_defs[obj_id.0].try_read();
|
||||
let Some(top_level_def) = top_level_defs.get(obj_id.0) else {
|
||||
return Err(HashSet::from([format!(
|
||||
"NameError: name '{id}' is not defined (at {})",
|
||||
expr.location
|
||||
)]));
|
||||
};
|
||||
let def_read = top_level_def.try_read();
|
||||
if let Some(def_read) = def_read {
|
||||
if let TopLevelDef::Class { type_vars, .. } = &*def_read {
|
||||
type_vars.clone()
|
||||
@ -156,12 +162,17 @@ pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
|
||||
}
|
||||
let obj_id = resolver.get_identifier_def(*id)?;
|
||||
let type_vars = {
|
||||
let def_read = top_level_defs[obj_id.0].try_read();
|
||||
let Some(top_level_def) = top_level_defs.get(obj_id.0) else {
|
||||
return Err(HashSet::from([format!(
|
||||
"NameError: name '{id}' is not defined (at {})",
|
||||
expr.location
|
||||
)]));
|
||||
};
|
||||
let def_read = top_level_def.try_read();
|
||||
if let Some(def_read) = def_read {
|
||||
let TopLevelDef::Class { type_vars, .. } = &*def_read else {
|
||||
unreachable!("must be class here")
|
||||
};
|
||||
|
||||
type_vars.clone()
|
||||
} else {
|
||||
locked.get(&obj_id).unwrap().clone()
|
||||
|
Loading…
Reference in New Issue
Block a user