diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index ab5e4d80..f552229f 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -1,6 +1,6 @@ use std::rc::Rc; -use nac3parser::ast::{fold::Fold, ExprKind}; +use nac3parser::ast::{fold::Fold, ExprKind, Ident}; use super::*; use crate::{ @@ -386,50 +386,17 @@ impl TopLevelComposer { let ExprKind::Name { id: name, .. } = target.node else { return Err(format!( "global variable declaration must be an identifier (at {})", - ast.location + target.location )); }; - if self.keyword_list.contains(&name) { - return Err(format!( - "cannot use keyword `{}` as a class name (at {})", - name, - ast.location - )); - } - - let global_var_name = if mod_path.is_empty() { - name.to_string() - } else { - format!("{mod_path}.{name}") - }; - if !defined_names.insert(global_var_name.clone()) { - return Err(format!( - "global variable `{}` defined twice (at {})", - global_var_name, - ast.location - )); - } - - let ty_to_be_unified = self.unifier.get_dummy_var().ty; - self.definition_ast_list.push(( - RwLock::new(Self::make_top_level_variable_def( - global_var_name, - name, - // dummy here, unify with correct type later, - ty_to_be_unified, - *(annotation.clone()), - resolver, - Some(ast.location), - )).into(), - None, - )); - - Ok(( + self.register_top_level_var( name, - DefinitionId(self.definition_ast_list.len() - 1), - Some(ty_to_be_unified), - )) + Some(annotation.as_ref().clone()), + resolver, + mod_path, + target.location, + ) } _ => Err(format!( @@ -439,6 +406,50 @@ impl TopLevelComposer { } } + /// Registers a top-level variable with the given `name` into the composer. + /// + /// `annotation` - The type annotation of the top-level variable, or [`None`] if no type + /// annotation is provided. + /// `location` - The location of the top-level variable. + pub fn register_top_level_var( + &mut self, + name: Ident, + annotation: Option, + resolver: Option>, + mod_path: &str, + location: Location, + ) -> Result<(StrRef, DefinitionId, Option), String> { + if self.keyword_list.contains(&name) { + return Err(format!("cannot use keyword `{name}` as a class name (at {location})")); + } + + let global_var_name = + if mod_path.is_empty() { name.to_string() } else { format!("{mod_path}.{name}") }; + + if !self.defined_names.insert(global_var_name.clone()) { + return Err(format!( + "global variable `{global_var_name}` defined twice (at {location})" + )); + } + + let ty_to_be_unified = self.unifier.get_dummy_var().ty; + self.definition_ast_list.push(( + RwLock::new(Self::make_top_level_variable_def( + global_var_name, + name, + // dummy here, unify with correct type later, + ty_to_be_unified, + annotation, + resolver, + Some(location), + )) + .into(), + None, + )); + + Ok((name, DefinitionId(self.definition_ast_list.len() - 1), Some(ty_to_be_unified))) + } + pub fn start_analysis(&mut self, inference: bool) -> Result<(), HashSet> { self.analyze_top_level_class_type_var()?; self.analyze_top_level_class_bases()?; @@ -2249,9 +2260,8 @@ impl TopLevelComposer { let primitives_store = &self.primitives_ty; let mut analyze = |variable_def: &Arc>| -> Result<_, HashSet> { - let variable_def = &mut *variable_def.write(); - - let TopLevelDef::Variable { ty: dummy_ty, ty_decl, resolver, loc, .. } = variable_def + let TopLevelDef::Variable { ty: dummy_ty, ty_decl, resolver, loc, .. } = + &*variable_def.read() else { // not top level variable def, skip return Ok(()); @@ -2259,25 +2269,28 @@ impl TopLevelComposer { let resolver = &**resolver.as_ref().unwrap(); - let ty_annotation = parse_ast_to_type_annotation_kinds( - resolver, - &temp_def_list, - unifier, - primitives_store, - ty_decl, - HashMap::new(), - )?; - let ty_from_ty_annotation = get_type_from_type_annotation_kinds( - &temp_def_list, - unifier, - primitives_store, - &ty_annotation, - &mut None, - )?; + if let Some(ty_decl) = ty_decl { + let ty_annotation = parse_ast_to_type_annotation_kinds( + resolver, + &temp_def_list, + unifier, + primitives_store, + ty_decl, + HashMap::new(), + )?; + let ty_from_ty_annotation = get_type_from_type_annotation_kinds( + &temp_def_list, + unifier, + primitives_store, + &ty_annotation, + &mut None, + )?; + + unifier.unify(*dummy_ty, ty_from_ty_annotation).map_err(|e| { + HashSet::from([e.at(Some(loc.unwrap())).to_display(unifier).to_string()]) + })?; + } - unifier.unify(*dummy_ty, ty_from_ty_annotation).map_err(|e| { - HashSet::from([e.at(Some(loc.unwrap())).to_display(unifier).to_string()]) - })?; Ok(()) }; diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index d674c51b..f2948754 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -600,7 +600,7 @@ impl TopLevelComposer { name: String, simple_name: StrRef, ty: Type, - ty_decl: Expr, + ty_decl: Option, resolver: Option>, loc: Option, ) -> TopLevelDef { diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index 8241c3ac..e0479a88 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -158,8 +158,8 @@ pub enum TopLevelDef { /// Type of the global variable. ty: Type, - /// The declared type of the global variable. - ty_decl: Expr, + /// The declared type of the global variable, or [`None`] if no type annotation is provided. + ty_decl: Option, /// Symbol resolver of the module defined the class. resolver: Option>, diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index e27c5e1f..9b1a601c 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -248,8 +248,9 @@ fn handle_assignment_pattern( fn handle_global_var( target: &Expr, value: Option<&Expr>, - resolver: &(dyn SymbolResolver + Send + Sync), + resolver: &Arc, internal_resolver: &ResolverInternal, + composer: &mut TopLevelComposer, ) -> Result<(), String> { let ExprKind::Name { id, .. } = target.node else { return Err(format!( @@ -262,8 +263,12 @@ fn handle_global_var( return Err(format!("global variable `{id}` must be initialized in its definition")); }; - if let Ok(val) = parse_parameter_default_value(value, resolver) { + if let Ok(val) = parse_parameter_default_value(value, &**resolver) { internal_resolver.add_module_global(id, val); + let (name, def_id, _) = composer + .register_top_level_var(id, None, Some(resolver.clone()), "__main__", target.location) + .unwrap(); + internal_resolver.add_id_def(name, def_id); Ok(()) } else { Err(format!( @@ -375,16 +380,12 @@ fn main() { if let Err(err) = handle_global_var( target, value.as_ref().map(Box::as_ref), - resolver.as_ref(), + &resolver, internal_resolver.as_ref(), + &mut composer, ) { panic!("{err}"); } - - let (name, def_id, _) = composer - .register_top_level(stmt, Some(resolver.clone()), "__main__", true) - .unwrap(); - internal_resolver.add_id_def(name, def_id); } // allow (and ignore) "from __future__ import annotations"