Compare commits

..

9 Commits

12 changed files with 1202 additions and 820 deletions

View File

@ -2886,7 +2886,31 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
Some((_, Some(static_value), _)) => ValueEnum::Static(static_value.clone()), Some((_, Some(static_value), _)) => ValueEnum::Static(static_value.clone()),
None => { None => {
let resolver = ctx.resolver.clone(); let resolver = ctx.resolver.clone();
resolver.get_symbol_value(*id, ctx, generator).unwrap() let value = resolver.get_symbol_value(*id, ctx, generator).unwrap();
let globals = ctx
.top_level
.definitions
.read()
.iter()
.filter_map(|def| {
if let TopLevelDef::Variable { simple_name, ty, .. } = &*def.read() {
Some((*simple_name, *ty))
} else {
None
}
})
.collect_vec();
if let Some((_, ty)) = globals.iter().find(|(name, _)| name == id) {
let ptr = value
.to_basic_value_enum(ctx, generator, *ty)
.map(BasicValueEnum::into_pointer_value)?;
ctx.builder.build_load(ptr, id.to_string().as_str()).map(Into::into).unwrap()
} else {
value
}
} }
}, },
ExprKind::List { elts, .. } => { ExprKind::List { elts, .. } => {

File diff suppressed because it is too large Load Diff

View File

@ -600,7 +600,7 @@ impl TopLevelComposer {
name: String, name: String,
simple_name: StrRef, simple_name: StrRef,
ty: Type, ty: Type,
ty_decl: Expr, ty_decl: Option<Expr>,
resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>, resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>,
loc: Option<Location>, loc: Option<Location>,
) -> TopLevelDef { ) -> TopLevelDef {
@ -626,6 +626,64 @@ impl TopLevelComposer {
Err(HashSet::from([format!("no method {method_name} in the current class")])) Err(HashSet::from([format!("no method {method_name} in the current class")]))
} }
/// get all base class def id of a class, excluding itself. \
/// this function should called only after the direct parent is set
/// and before all the ancestors are set
/// and when we allow single inheritance \
/// the order of the returned list is from the child to the deepest ancestor
pub fn get_all_ancestors_helper(
child: &TypeAnnotation,
temp_def_list: &[Arc<RwLock<TopLevelDef>>],
) -> Result<Vec<TypeAnnotation>, HashSet<String>> {
let mut result: Vec<TypeAnnotation> = Vec::new();
let mut parent = Self::get_parent(child, temp_def_list);
while let Some(p) = parent {
parent = Self::get_parent(&p, temp_def_list);
let p_id = if let TypeAnnotation::CustomClass { id, .. } = &p {
*id
} else {
unreachable!("must be class kind annotation")
};
// check cycle
let no_cycle = result.iter().all(|x| {
let TypeAnnotation::CustomClass { id, .. } = x else {
unreachable!("must be class kind annotation")
};
id.0 != p_id.0
});
if no_cycle {
result.push(p);
} else {
return Err(HashSet::from(["cyclic inheritance detected".into()]));
}
}
Ok(result)
}
/// should only be called when finding all ancestors, so panic when wrong
fn get_parent(
child: &TypeAnnotation,
temp_def_list: &[Arc<RwLock<TopLevelDef>>],
) -> Option<TypeAnnotation> {
let child_id = if let TypeAnnotation::CustomClass { id, .. } = child {
*id
} else {
unreachable!("should be class type annotation")
};
let child_def = temp_def_list.get(child_id.0).unwrap();
let child_def = child_def.read();
let TopLevelDef::Class { ancestors, .. } = &*child_def else {
unreachable!("child must be top level class def")
};
if ancestors.is_empty() {
None
} else {
Some(ancestors[0].clone())
}
}
/// get the `var_id` of a given `TVar` type /// get the `var_id` of a given `TVar` type
pub fn get_var_id(var_ty: Type, unifier: &mut Unifier) -> Result<TypeVarId, HashSet<String>> { pub fn get_var_id(var_ty: Type, unifier: &mut Unifier) -> Result<TypeVarId, HashSet<String>> {
if let TypeEnum::TVar { id, .. } = unifier.get_ty(var_ty).as_ref() { if let TypeEnum::TVar { id, .. } = unifier.get_ty(var_ty).as_ref() {
@ -935,139 +993,6 @@ impl TopLevelComposer {
)) ))
} }
} }
/// Parses the class type variables and direct parents
/// we only allow single inheritance
pub fn analyze_class_bases(
class_def: &Arc<RwLock<TopLevelDef>>,
class_ast: &Option<Stmt>,
temp_def_list: &[Arc<RwLock<TopLevelDef>>],
unifier: &mut Unifier,
primitives_store: &PrimitiveStore,
) -> Result<(), HashSet<String>> {
let mut class_def = class_def.write();
let (class_def_id, class_ancestors, class_bases_ast, class_type_vars, class_resolver) = {
let TopLevelDef::Class { object_id, ancestors, type_vars, resolver, .. } =
&mut *class_def
else {
unreachable!()
};
let Some(ast::Located { node: ast::StmtKind::ClassDef { bases, .. }, .. }) = class_ast
else {
unreachable!()
};
(object_id, ancestors, bases, type_vars, resolver.as_ref().unwrap().as_ref())
};
let mut is_generic = false;
let mut has_base = false;
// Check class bases for typevars
for b in class_bases_ast {
match &b.node {
// analyze typevars bounded to the class,
// only support things like `class A(Generic[T, V])`,
// things like `class A(Generic[T, V, ImportedModule.T])` is not supported
// i.e. only simple names are allowed in the subscript
// should update the TopLevelDef::Class.typevars and the TypeEnum::TObj.params
ast::ExprKind::Subscript { value, slice, .. } if matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"Generic".into()) =>
{
if is_generic {
return Err(HashSet::from([format!(
"only single Generic[...] is allowed (at {})",
b.location
)]));
}
is_generic = true;
let type_var_list: Vec<&ast::Expr<()>>;
// if `class A(Generic[T, V, G])`
if let ast::ExprKind::Tuple { elts, .. } = &slice.node {
type_var_list = elts.iter().collect_vec();
// `class A(Generic[T])`
} else {
type_var_list = vec![&**slice];
}
let type_vars = type_var_list
.into_iter()
.map(|e| {
class_resolver.parse_type_annotation(
temp_def_list,
unifier,
primitives_store,
e,
)
})
.collect::<Result<Vec<_>, _>>()?;
class_type_vars.extend(type_vars);
}
ast::ExprKind::Name { .. } | ast::ExprKind::Subscript { .. } => {
if has_base {
return Err(HashSet::from([format!("a class definition can only have at most one base class declaration and one generic declaration (at {})", b.location )]));
}
has_base = true;
// the function parse_ast_to make sure that no type var occurred in
// bast_ty if it is a CustomClassKind
let base_ty = parse_ast_to_type_annotation_kinds(
class_resolver,
temp_def_list,
unifier,
primitives_store,
b,
vec![(*class_def_id, class_type_vars.clone())]
.into_iter()
.collect::<HashMap<_, _>>(),
)?;
if let TypeAnnotation::CustomClass { .. } = &base_ty {
class_ancestors.push(base_ty);
} else {
return Err(HashSet::from([format!(
"class base declaration can only be custom class (at {})",
b.location
)]));
}
}
_ => {
return Err(HashSet::from([format!(
"unsupported statement in class defintion (at {})",
b.location
)]));
}
}
}
Ok(())
}
/// gets all ancestors of a class
pub fn analyze_class_ancestors(
class_def: &Arc<RwLock<TopLevelDef>>,
temp_def_list: &[Arc<RwLock<TopLevelDef>>],
) {
// Check if class has a direct parent
let mut class_def = class_def.write();
let TopLevelDef::Class { ancestors, type_vars, object_id, .. } = &mut *class_def else {
unreachable!()
};
let mut anc_set = HashMap::new();
if let Some(ancestor) = ancestors.first() {
let TypeAnnotation::CustomClass { id, .. } = ancestor else { unreachable!() };
let TopLevelDef::Class { ancestors: parent_ancestors, .. } =
&*temp_def_list[id.0].read()
else {
unreachable!()
};
for anc in parent_ancestors.iter().skip(1) {
let TypeAnnotation::CustomClass { id, .. } = anc else { unreachable!() };
anc_set.insert(id, anc.clone());
}
ancestors.extend(anc_set.into_values());
}
// push `self` as first ancestor of class
ancestors.insert(0, make_self_type_annotation(type_vars.as_slice(), *object_id));
}
} }
pub fn parse_parameter_default_value( pub fn parse_parameter_default_value(

View File

@ -158,8 +158,8 @@ pub enum TopLevelDef {
/// Type of the global variable. /// Type of the global variable.
ty: Type, ty: Type,
/// The declared type of the global variable. /// The declared type of the global variable, or [`None`] if no type annotation is provided.
ty_decl: Expr, ty_decl: Option<Expr>,
/// Symbol resolver of the module defined the class. /// Symbol resolver of the module defined the class.
resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>, resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>,

View File

@ -3,10 +3,10 @@ source: nac3core/src/toplevel/test.rs
expression: res_vec expression: res_vec
--- ---
[ [
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(241)]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(246)]\n}\n",
] ]

View File

@ -7,11 +7,11 @@ expression: res_vec
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(249)]\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(249)]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(257)]\n}\n", "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(257)]\n}\n",
] ]

View File

@ -230,6 +230,11 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
def foo(self, a: T, b: V): def foo(self, a: T, b: V):
pass pass
"}, "},
indoc! {"
class B(C):
def __init__(self):
pass
"},
indoc! {" indoc! {"
class C(A): class C(A):
def __init__(self): def __init__(self):
@ -238,11 +243,6 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
a = 1 a = 1
pass pass
"}, "},
indoc! {"
class B(C):
def __init__(self):
pass
"},
indoc! {" indoc! {"
def foo(a: A): def foo(a: A):
pass pass
@ -257,14 +257,6 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
)] )]
#[test_case( #[test_case(
&[ &[
indoc! {"
class B:
aa: bool
def __init__(self):
self.aa = False
def foo(self, b: T):
pass
"},
indoc! {" indoc! {"
class Generic_A(Generic[V], B): class Generic_A(Generic[V], B):
a: int64 a: int64
@ -272,6 +264,14 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
self.a = 123123123123 self.a = 123123123123
def fun(self, a: int32) -> V: def fun(self, a: int32) -> V:
pass pass
"},
indoc! {"
class B:
aa: bool
def __init__(self):
self.aa = False
def foo(self, b: T):
pass
"} "}
], ],
&[]; &[];
@ -391,18 +391,18 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
pass pass
"} "}
], ],
&["NameError: name 'B' is not defined (at unknown:1:9)"]; &["cyclic inheritance detected"];
"cyclic1" "cyclic1"
)] )]
#[test_case( #[test_case(
&[ &[
indoc! {" indoc! {"
class B(Generic[V, T], C[int32]): class A(B[bool, int64]):
def __init__(self): def __init__(self):
pass pass
"}, "},
indoc! {" indoc! {"
class A(B[bool, int64]): class B(Generic[V, T], C[int32]):
def __init__(self): def __init__(self):
pass pass
"}, "},
@ -412,7 +412,7 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
pass pass
"}, "},
], ],
&["NameError: name 'C' is not defined (at unknown:1:25)"]; &["cyclic inheritance detected"];
"cyclic2" "cyclic2"
)] )]
#[test_case( #[test_case(
@ -436,6 +436,11 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
)] )]
#[test_case( #[test_case(
&[ &[
indoc! {"
class A(B, Generic[T], C):
def __init__(self):
pass
"},
indoc! {" indoc! {"
class B: class B:
def __init__(self): def __init__(self):
@ -445,11 +450,6 @@ fn test_simple_function_analyze(source: &[&str], tys: &[&str], names: &[&str]) {
class C: class C:
def __init__(self): def __init__(self):
pass pass
"},
indoc! {"
class A(B, Generic[T], C):
def __init__(self):
pass
"} "}
], ],

View File

@ -101,13 +101,7 @@ pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
Ok(TypeAnnotation::CustomClass { id: PrimDef::Exception.id(), params: Vec::default() }) Ok(TypeAnnotation::CustomClass { id: PrimDef::Exception.id(), params: Vec::default() })
} else if let Ok(obj_id) = resolver.get_identifier_def(*id) { } else if let Ok(obj_id) = resolver.get_identifier_def(*id) {
let type_vars = { let type_vars = {
let Some(top_level_def) = top_level_defs.get(obj_id.0) else { let def_read = top_level_defs[obj_id.0].try_read();
return Err(HashSet::from([format!(
"NameError: name '{id}' is not defined (at {})",
expr.location
)]));
};
let def_read = top_level_def.try_read();
if let Some(def_read) = def_read { if let Some(def_read) = def_read {
if let TopLevelDef::Class { type_vars, .. } = &*def_read { if let TopLevelDef::Class { type_vars, .. } = &*def_read {
type_vars.clone() type_vars.clone()
@ -162,17 +156,12 @@ pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
} }
let obj_id = resolver.get_identifier_def(*id)?; let obj_id = resolver.get_identifier_def(*id)?;
let type_vars = { let type_vars = {
let Some(top_level_def) = top_level_defs.get(obj_id.0) else { let def_read = top_level_defs[obj_id.0].try_read();
return Err(HashSet::from([format!(
"NameError: name '{id}' is not defined (at {})",
expr.location
)]));
};
let def_read = top_level_def.try_read();
if let Some(def_read) = def_read { if let Some(def_read) = def_read {
let TopLevelDef::Class { type_vars, .. } = &*def_read else { let TopLevelDef::Class { type_vars, .. } = &*def_read else {
unreachable!("must be class here") unreachable!("must be class here")
}; };
type_vars.clone() type_vars.clone()
} else { } else {
locked.get(&obj_id).unwrap().clone() locked.get(&obj_id).unwrap().clone()

View File

@ -10,7 +10,7 @@ use nac3parser::ast::{
}; };
use super::{ use super::{
type_inferencer::{IdentifierInfo, Inferencer}, type_inferencer::{DeclarationSource, IdentifierInfo, Inferencer},
typedef::{Type, TypeEnum}, typedef::{Type, TypeEnum},
}; };
use crate::toplevel::helper::PrimDef; use crate::toplevel::helper::PrimDef;
@ -34,6 +34,20 @@ impl<'a> Inferencer<'a> {
Err(HashSet::from([format!("cannot assign to a `none` (at {})", pattern.location)])) Err(HashSet::from([format!("cannot assign to a `none` (at {})", pattern.location)]))
} }
ExprKind::Name { id, .. } => { ExprKind::Name { id, .. } => {
// If `id` refers to a declared symbol, reject this assignment if it is used in the
// context of an (implicit) global variable
if let Some(id_info) = defined_identifiers.get(id) {
if matches!(
id_info.source,
DeclarationSource::Global { is_explicit: Some(false) }
) {
return Err(HashSet::from([format!(
"cannot access local variable '{id}' before it is declared (at {})",
pattern.location
)]));
}
}
if !defined_identifiers.contains_key(id) { if !defined_identifiers.contains_key(id) {
defined_identifiers.insert(*id, IdentifierInfo::default()); defined_identifiers.insert(*id, IdentifierInfo::default());
} }
@ -104,7 +118,22 @@ impl<'a> Inferencer<'a> {
*id, *id,
) { ) {
Ok(_) => { Ok(_) => {
self.defined_identifiers.insert(*id, IdentifierInfo::default()); let is_global = self.is_id_global(*id);
defined_identifiers.insert(
*id,
IdentifierInfo {
source: match is_global {
Some(true) => {
DeclarationSource::Global { is_explicit: Some(false) }
}
Some(false) => {
DeclarationSource::Global { is_explicit: None }
}
None => DeclarationSource::Local,
},
},
);
} }
Err(e) => { Err(e) => {
return Err(HashSet::from([format!( return Err(HashSet::from([format!(
@ -368,9 +397,9 @@ impl<'a> Inferencer<'a> {
StmtKind::Global { names, .. } => { StmtKind::Global { names, .. } => {
for id in names { for id in names {
if let Some(id_info) = defined_identifiers.get(id) { if let Some(id_info) = defined_identifiers.get(id) {
if !id_info.is_global { if id_info.source == DeclarationSource::Local {
return Err(HashSet::from([format!( return Err(HashSet::from([format!(
"name '{id}' is assigned to before global declaration at {}", "name '{id}' is referenced prior to global declaration at {}",
stmt.location, stmt.location,
)])); )]));
} }
@ -385,8 +414,12 @@ impl<'a> Inferencer<'a> {
*id, *id,
) { ) {
Ok(_) => { Ok(_) => {
self.defined_identifiers defined_identifiers.insert(
.insert(*id, IdentifierInfo { is_global: true }); *id,
IdentifierInfo {
source: DeclarationSource::Global { is_explicit: Some(true) },
},
);
} }
Err(e) => { Err(e) => {
return Err(HashSet::from([format!( return Err(HashSet::from([format!(

View File

@ -12,7 +12,7 @@ use itertools::{izip, Itertools};
use nac3parser::ast::{ use nac3parser::ast::{
self, self,
fold::{self, Fold}, fold::{self, Fold},
Arguments, Comprehension, ExprContext, ExprKind, Located, Location, StrRef, Arguments, Comprehension, ExprContext, ExprKind, Ident, Located, Location, StrRef,
}; };
use super::{ use super::{
@ -88,11 +88,31 @@ impl PrimitiveStore {
} }
} }
/// The location where an identifier declaration refers to.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DeclarationSource {
/// Local scope.
Local,
/// Global scope.
Global {
/// Whether the identifier is declared by the use of `global` statement. This field is
/// [`None`] if the identifier does not refer to a variable.
is_explicit: Option<bool>,
},
}
/// Information regarding a defined identifier. /// Information regarding a defined identifier.
#[derive(Clone, Copy, Debug, Default)] #[derive(Clone, Copy, Debug)]
pub struct IdentifierInfo { pub struct IdentifierInfo {
/// Whether this identifier refers to a global variable. /// Whether this identifier refers to a global variable.
pub is_global: bool, pub source: DeclarationSource,
}
impl Default for IdentifierInfo {
fn default() -> Self {
IdentifierInfo { source: DeclarationSource::Local }
}
} }
impl IdentifierInfo { impl IdentifierInfo {
@ -574,7 +594,22 @@ impl<'a> Fold<()> for Inferencer<'a> {
*id, *id,
) { ) {
Ok(_) => { Ok(_) => {
self.defined_identifiers.insert(*id, IdentifierInfo::default()); let is_global = self.is_id_global(*id);
self.defined_identifiers.insert(
*id,
IdentifierInfo {
source: match is_global {
Some(true) => DeclarationSource::Global {
is_explicit: Some(false),
},
Some(false) => {
DeclarationSource::Global { is_explicit: None }
}
None => DeclarationSource::Local,
},
},
);
} }
Err(e) => { Err(e) => {
return report_error( return report_error(
@ -2650,4 +2685,22 @@ impl<'a> Inferencer<'a> {
self.constrain(body.custom.unwrap(), orelse.custom.unwrap(), &body.location)?; self.constrain(body.custom.unwrap(), orelse.custom.unwrap(), &body.location)?;
Ok(body.custom.unwrap()) Ok(body.custom.unwrap())
} }
/// Determines whether the given `id` refers to a global symbol.
///
/// Returns `Some(true)` if `id` refers to a global variable, `Some(false)` if `id` refers to a
/// class/function, and `None` if `id` refers to a local symbol.
pub(super) fn is_id_global(&self, id: Ident) -> Option<bool> {
self.top_level
.definitions
.read()
.iter()
.map(|def| match *def.read() {
TopLevelDef::Class { name, .. } => (name, false),
TopLevelDef::Function { simple_name, .. } => (simple_name, false),
TopLevelDef::Variable { simple_name, .. } => (simple_name, true),
})
.find(|(global, _)| global == &id)
.map(|(_, has_explicit_prop)| has_explicit_prop)
}
} }

View File

@ -7,7 +7,7 @@ def output_int64(x: int64):
... ...
X: int32 = 0 X: int32 = 0
Y: int64 = int64(1) Y = int64(1)
def f(): def f():
global X, Y global X, Y

View File

@ -174,46 +174,49 @@ fn handle_typevar_definition(
fn handle_assignment_pattern( fn handle_assignment_pattern(
targets: &[Expr], targets: &[Expr],
value: &Expr, value: &Expr,
resolver: &(dyn SymbolResolver + Send + Sync), resolver: Arc<dyn SymbolResolver + Send + Sync>,
internal_resolver: &ResolverInternal, internal_resolver: &ResolverInternal,
def_list: &[Arc<RwLock<TopLevelDef>>], composer: &mut TopLevelComposer,
unifier: &mut Unifier,
primitives: &PrimitiveStore,
) -> Result<(), String> { ) -> Result<(), String> {
if targets.len() == 1 { if targets.len() == 1 {
match &targets[0].node { let target = &targets[0];
match &target.node {
ExprKind::Name { id, .. } => { ExprKind::Name { id, .. } => {
let def_list = composer.extract_def_list();
let unifier = &mut composer.unifier;
let primitives = &composer.primitives_ty;
if let Ok(var) = if let Ok(var) =
handle_typevar_definition(value, resolver, def_list, unifier, primitives) handle_typevar_definition(value, &*resolver, &def_list, unifier, primitives)
{ {
internal_resolver.add_id_type(*id, var); internal_resolver.add_id_type(*id, var);
Ok(()) Ok(())
} else if let Ok(val) = parse_parameter_default_value(value, resolver) { } else if let Ok(val) = parse_parameter_default_value(value, &*resolver) {
internal_resolver.add_module_global(*id, val); internal_resolver.add_module_global(*id, val);
let (name, def_id, _) = composer
.register_top_level_var(
*id,
None,
Some(resolver.clone()),
"__main__",
target.location,
)
.unwrap();
internal_resolver.add_id_def(name, def_id);
Ok(()) Ok(())
} else { } else {
Err(format!("fails to evaluate this expression `{:?}` as a constant or generic parameter at {}", Err(format!("fails to evaluate this expression `{:?}` as a constant or generic parameter at {}",
targets[0].node, target.node,
targets[0].location, target.location,
)) ))
} }
} }
ExprKind::List { elts, .. } | ExprKind::Tuple { elts, .. } => { ExprKind::List { elts, .. } | ExprKind::Tuple { elts, .. } => {
handle_assignment_pattern( handle_assignment_pattern(elts, value, resolver, internal_resolver, composer)?;
elts,
value,
resolver,
internal_resolver,
def_list,
unifier,
primitives,
)?;
Ok(()) Ok(())
} }
_ => Err(format!( _ => Err(format!("assignment to {target:?} is not supported at {}", target.location)),
"assignment to {:?} is not supported at {}",
targets[0], targets[0].location
)),
} }
} else { } else {
match &value.node { match &value.node {
@ -223,11 +226,9 @@ fn handle_assignment_pattern(
handle_assignment_pattern( handle_assignment_pattern(
std::slice::from_ref(tar), std::slice::from_ref(tar),
val, val,
resolver, resolver.clone(),
internal_resolver, internal_resolver,
def_list, composer,
unifier,
primitives,
)?; )?;
} }
Ok(()) Ok(())
@ -248,8 +249,9 @@ fn handle_assignment_pattern(
fn handle_global_var( fn handle_global_var(
target: &Expr, target: &Expr,
value: Option<&Expr>, value: Option<&Expr>,
resolver: &(dyn SymbolResolver + Send + Sync), resolver: &Arc<dyn SymbolResolver + Send + Sync>,
internal_resolver: &ResolverInternal, internal_resolver: &ResolverInternal,
composer: &mut TopLevelComposer,
) -> Result<(), String> { ) -> Result<(), String> {
let ExprKind::Name { id, .. } = target.node else { let ExprKind::Name { id, .. } = target.node else {
return Err(format!( return Err(format!(
@ -262,8 +264,12 @@ fn handle_global_var(
return Err(format!("global variable `{id}` must be initialized in its definition")); return Err(format!("global variable `{id}` must be initialized in its definition"));
}; };
if let Ok(val) = parse_parameter_default_value(value, resolver) { if let Ok(val) = parse_parameter_default_value(value, &**resolver) {
internal_resolver.add_module_global(id, val); internal_resolver.add_module_global(id, val);
let (name, def_id, _) = composer
.register_top_level_var(id, None, Some(resolver.clone()), "__main__", target.location)
.unwrap();
internal_resolver.add_id_def(name, def_id);
Ok(()) Ok(())
} else { } else {
Err(format!( Err(format!(
@ -355,17 +361,12 @@ fn main() {
for stmt in parser_result { for stmt in parser_result {
match &stmt.node { match &stmt.node {
StmtKind::Assign { targets, value, .. } => { StmtKind::Assign { targets, value, .. } => {
let def_list = composer.extract_def_list();
let unifier = &mut composer.unifier;
let primitives = &composer.primitives_ty;
if let Err(err) = handle_assignment_pattern( if let Err(err) = handle_assignment_pattern(
targets, targets,
value, value,
resolver.as_ref(), resolver.clone(),
internal_resolver.as_ref(), internal_resolver.as_ref(),
&def_list, &mut composer,
unifier,
primitives,
) { ) {
panic!("{err}"); panic!("{err}");
} }
@ -375,16 +376,12 @@ fn main() {
if let Err(err) = handle_global_var( if let Err(err) = handle_global_var(
target, target,
value.as_ref().map(Box::as_ref), value.as_ref().map(Box::as_ref),
resolver.as_ref(), &resolver,
internal_resolver.as_ref(), internal_resolver.as_ref(),
&mut composer,
) { ) {
panic!("{err}"); panic!("{err}");
} }
let (name, def_id, _) = composer
.register_top_level(stmt, Some(resolver.clone()), "__main__", true)
.unwrap();
internal_resolver.add_id_def(name, def_id);
} }
// allow (and ignore) "from __future__ import annotations" // allow (and ignore) "from __future__ import annotations"