From 3ce2eddcdc5db3212d16c396a5f8d93375a2b6e0 Mon Sep 17 00:00:00 2001
From: David Mak <chmakac@connect.ust.hk>
Date: Sat, 5 Oct 2024 17:07:13 +0800
Subject: [PATCH] [core] typecheck/type_inferencer: Infer whether variables are
 global

---
 nac3core/src/typecheck/function_check.rs      | 21 +++++++++--
 nac3core/src/typecheck/type_inferencer/mod.rs | 37 ++++++++++++++++++-
 2 files changed, 53 insertions(+), 5 deletions(-)

diff --git a/nac3core/src/typecheck/function_check.rs b/nac3core/src/typecheck/function_check.rs
index c4626f2f6..25fc6b764 100644
--- a/nac3core/src/typecheck/function_check.rs
+++ b/nac3core/src/typecheck/function_check.rs
@@ -104,7 +104,22 @@ impl<'a> Inferencer<'a> {
                         *id,
                     ) {
                         Ok(_) => {
-                            self.defined_identifiers.insert(*id, IdentifierInfo::default());
+                            let is_global = self.is_id_global(*id);
+
+                            defined_identifiers.insert(
+                                *id,
+                                IdentifierInfo {
+                                    source: match is_global {
+                                        Some(true) => {
+                                            DeclarationSource::Global { is_explicit: Some(false) }
+                                        }
+                                        Some(false) => {
+                                            DeclarationSource::Global { is_explicit: None }
+                                        }
+                                        None => DeclarationSource::Local,
+                                    },
+                                },
+                            );
                         }
                         Err(e) => {
                             return Err(HashSet::from([format!(
@@ -370,7 +385,7 @@ impl<'a> Inferencer<'a> {
                     if let Some(id_info) = defined_identifiers.get(id) {
                         if id_info.source == DeclarationSource::Local {
                             return Err(HashSet::from([format!(
-                                "name '{id}' is assigned to before global declaration at {}",
+                                "name '{id}' is referenced prior to global declaration at {}",
                                 stmt.location,
                             )]));
                         }
@@ -385,7 +400,7 @@ impl<'a> Inferencer<'a> {
                         *id,
                     ) {
                         Ok(_) => {
-                            self.defined_identifiers.insert(
+                            defined_identifiers.insert(
                                 *id,
                                 IdentifierInfo {
                                     source: DeclarationSource::Global { is_explicit: Some(true) },
diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs
index dde45d794..e5a6cf892 100644
--- a/nac3core/src/typecheck/type_inferencer/mod.rs
+++ b/nac3core/src/typecheck/type_inferencer/mod.rs
@@ -12,7 +12,7 @@ use itertools::{izip, Itertools};
 use nac3parser::ast::{
     self,
     fold::{self, Fold},
-    Arguments, Comprehension, ExprContext, ExprKind, Located, Location, StrRef,
+    Arguments, Comprehension, ExprContext, ExprKind, Ident, Located, Location, StrRef,
 };
 
 use super::{
@@ -594,7 +594,22 @@ impl<'a> Fold<()> for Inferencer<'a> {
                             *id,
                         ) {
                             Ok(_) => {
-                                self.defined_identifiers.insert(*id, IdentifierInfo::default());
+                                let is_global = self.is_id_global(*id);
+
+                                self.defined_identifiers.insert(
+                                    *id,
+                                    IdentifierInfo {
+                                        source: match is_global {
+                                            Some(true) => DeclarationSource::Global {
+                                                is_explicit: Some(false),
+                                            },
+                                            Some(false) => {
+                                                DeclarationSource::Global { is_explicit: None }
+                                            }
+                                            None => DeclarationSource::Local,
+                                        },
+                                    },
+                                );
                             }
                             Err(e) => {
                                 return report_error(
@@ -2670,4 +2685,22 @@ impl<'a> Inferencer<'a> {
         self.constrain(body.custom.unwrap(), orelse.custom.unwrap(), &body.location)?;
         Ok(body.custom.unwrap())
     }
+
+    /// Determines whether the given `id` refers to a global symbol.
+    ///
+    /// Returns `Some(true)` if `id` refers to a global variable, `Some(false)` if `id` refers to a
+    /// class/function, and `None` if `id` refers to a local symbol.
+    pub(super) fn is_id_global(&self, id: Ident) -> Option<bool> {
+        self.top_level
+            .definitions
+            .read()
+            .iter()
+            .map(|def| match *def.read() {
+                TopLevelDef::Class { name, .. } => (name, false),
+                TopLevelDef::Function { simple_name, .. } => (simple_name, false),
+                TopLevelDef::Variable { simple_name, .. } => (simple_name, true),
+            })
+            .find(|(global, _)| global == &id)
+            .map(|(_, has_explicit_prop)| has_explicit_prop)
+    }
 }