From 32f24261f280cfe36d75a08313697e540822655e Mon Sep 17 00:00:00 2001
From: abdul124 <ar@m-labs.hk>
Date: Thu, 16 Jan 2025 11:08:55 +0800
Subject: [PATCH] [artiq] add global variables to modules

---
 nac3artiq/src/lib.rs              |  41 +++++---
 nac3artiq/src/symbol_resolver.rs  | 160 +++++++++++++++++++++++++++++-
 nac3core/src/toplevel/composer.rs |  56 +++++++----
 nac3core/src/toplevel/helper.rs   |   9 +-
 nac3core/src/toplevel/mod.rs      |   6 +-
 5 files changed, 234 insertions(+), 38 deletions(-)

diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs
index 78f427e5..ba6c4fae 100644
--- a/nac3artiq/src/lib.rs
+++ b/nac3artiq/src/lib.rs
@@ -277,6 +277,10 @@ impl Nac3 {
                         }
                     })
                 }
+                // Allow global variable declaration with `Kernel` type annotation
+                StmtKind::AnnAssign { ref annotation, .. } => {
+                    matches!(&annotation.node, ExprKind::Subscript { value, .. } if matches!(&value.node, ExprKind::Name {id, ..} if id == &"Kernel".into()))
+                }
                 _ => false,
             };
 
@@ -522,8 +526,15 @@ impl Nac3 {
                         as Arc<dyn SymbolResolver + Send + Sync>;
                     let name_to_pyid = Rc::new(name_to_pyid);
                     let module_location = ast::Location::new(1, 1, stmt.location.file);
-                    module_to_resolver_cache
-                        .insert(module_id, (name_to_pyid.clone(), resolver.clone(), module_name.clone(), Some(module_location)));
+                    module_to_resolver_cache.insert(
+                        module_id,
+                        (
+                            name_to_pyid.clone(),
+                            resolver.clone(),
+                            module_name.clone(),
+                            Some(module_location),
+                        ),
+                    );
                     (name_to_pyid, resolver, module_name, Some(module_location))
                 });
 
@@ -599,15 +610,19 @@ impl Nac3 {
         }
 
         // Adding top level module definitions
-        for (module_id, (module_name_to_pyid, module_resolver, module_name, module_location)) in module_to_resolver_cache.into_iter() {
-            let def_id= composer.register_top_level_module(
-                module_name,
-                module_name_to_pyid,
-                module_resolver,
-                module_location
-            ).map_err(|e| {
-                CompileError::new_err(format!("compilation failed\n----------\n{e}"))
-            })?;
+        for (module_id, (module_name_to_pyid, module_resolver, module_name, module_location)) in
+            module_to_resolver_cache
+        {
+            let def_id = composer
+                .register_top_level_module(
+                    &module_name,
+                    &module_name_to_pyid,
+                    module_resolver,
+                    module_location,
+                )
+                .map_err(|e| {
+                    CompileError::new_err(format!("compilation failed\n----------\n{e}"))
+                })?;
 
             self.pyid_to_def.write().insert(module_id, def_id);
         }
@@ -731,7 +746,9 @@ impl Nac3 {
                             "Unsupported @rpc annotation on global variable",
                         )))
                     }
-                    TopLevelDef::Module { .. } => unreachable!("Type module cannot be decorated with @rpc"),
+                    TopLevelDef::Module { .. } => {
+                        unreachable!("Type module cannot be decorated with @rpc")
+                    }
                 }
             }
         }
diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs
index d9768669..4b398a9b 100644
--- a/nac3artiq/src/symbol_resolver.rs
+++ b/nac3artiq/src/symbol_resolver.rs
@@ -23,7 +23,7 @@ use nac3core::{
     inkwell::{
         module::Linkage,
         types::{BasicType, BasicTypeEnum},
-        values::BasicValueEnum,
+        values::{BasicValue, BasicValueEnum},
         AddressSpace,
     },
     nac3parser::ast::{self, StrRef},
@@ -674,6 +674,48 @@ impl InnerResolver {
             })
         });
 
+        // check if obj is module
+        if self.helper.id_fn.call1(py, (ty.clone(),))?.extract::<u64>(py)?
+            == self.primitive_ids.module
+            && self.pyid_to_def.read().contains_key(&py_obj_id)
+        {
+            let def_id = self.pyid_to_def.read()[&py_obj_id];
+            let def = defs[def_id.0].read();
+            let TopLevelDef::Module { name: module_name, module_id, attributes, methods, .. } =
+                &*def
+            else {
+                unreachable!("must be a module here");
+            };
+            // Construct the module return type
+            let mut module_attributes = HashMap::new();
+            for (name, _) in attributes {
+                let attribute_obj = obj.getattr(name.to_string().as_str())?;
+                let attribute_ty =
+                    self.get_obj_type(py, attribute_obj, unifier, defs, primitives)?;
+                if let Ok(attribute_ty) = attribute_ty {
+                    module_attributes.insert(*name, (attribute_ty, false));
+                } else {
+                    return Ok(Err(format!("Unable to resolve {module_name}.{name}")));
+                }
+            }
+
+            for name in methods.keys() {
+                let method_obj = obj.getattr(name.to_string().as_str())?;
+                let method_ty = self.get_obj_type(py, method_obj, unifier, defs, primitives)?;
+                if let Ok(method_ty) = method_ty {
+                    module_attributes.insert(*name, (method_ty, true));
+                } else {
+                    return Ok(Err(format!("Unable to resolve {module_name}.{name}")));
+                }
+            }
+
+            let module_ty =
+                TypeEnum::TModule { module_id: *module_id, attributes: module_attributes };
+
+            let ty = unifier.add_ty(module_ty);
+            return Ok(Ok(ty));
+        }
+
         if let Some(ty) = constructor_ty {
             self.pyid_to_type.write().insert(py_obj_id, ty);
             return Ok(Ok(ty));
@@ -1373,6 +1415,77 @@ impl InnerResolver {
                     None => Ok(None),
                 }
             }
+        } else if ty_id == self.primitive_ids.module {
+            let id_str = id.to_string();
+
+            if let Some(global) = ctx.module.get_global(&id_str) {
+                return Ok(Some(global.as_pointer_value().into()));
+            }
+
+            let top_level_defs = ctx.top_level.definitions.read();
+            let ty = self
+                .get_obj_type(py, obj, &mut ctx.unifier, &top_level_defs, &ctx.primitives)?
+                .unwrap();
+            let ty = ctx
+                .get_llvm_type(generator, ty)
+                .into_pointer_type()
+                .get_element_type()
+                .into_struct_type();
+
+            {
+                if self.global_value_ids.read().contains_key(&id) {
+                    let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
+                        ctx.module.add_global(ty, Some(AddressSpace::default()), &id_str)
+                    });
+                    return Ok(Some(global.as_pointer_value().into()));
+                }
+                self.global_value_ids.write().insert(id, obj.into());
+            }
+
+            let fields = {
+                let definition =
+                    top_level_defs.get(self.pyid_to_def.read().get(&id).unwrap().0).unwrap().read();
+                let TopLevelDef::Module { attributes, .. } = &*definition else { unreachable!() };
+                attributes
+                    .iter()
+                    .filter_map(|f| {
+                        let definition = top_level_defs.get(f.1 .0).unwrap().read();
+                        if let TopLevelDef::Variable { ty, .. } = &*definition {
+                            Some((f.0, *ty))
+                        } else {
+                            None
+                        }
+                    })
+                    .collect_vec()
+            };
+
+            let values: Result<Option<Vec<_>>, _> = fields
+                .iter()
+                .map(|(name, ty)| {
+                    self.get_obj_value(
+                        py,
+                        obj.getattr(name.to_string().as_str())?,
+                        ctx,
+                        generator,
+                        *ty,
+                    )
+                    .map_err(|e| {
+                        super::CompileError::new_err(format!("Error getting field {name}: {e}"))
+                    })
+                })
+                .collect();
+            let values = values?;
+
+            if let Some(values) = values {
+                let val = ty.const_named_struct(&values);
+                let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
+                    ctx.module.add_global(ty, Some(AddressSpace::default()), &id_str)
+                });
+                global.set_initializer(&val);
+                Ok(Some(global.as_pointer_value().into()))
+            } else {
+                Ok(None)
+            }
         } else {
             let id_str = id.to_string();
 
@@ -1555,9 +1668,50 @@ impl SymbolResolver for Resolver {
     fn get_symbol_value<'ctx>(
         &self,
         id: StrRef,
-        _: &mut CodeGenContext<'ctx, '_>,
-        _: &mut dyn CodeGenerator,
+        ctx: &mut CodeGenContext<'ctx, '_>,
+        generator: &mut dyn CodeGenerator,
     ) -> Option<ValueEnum<'ctx>> {
+        if let Some(def_id) = self.0.id_to_def.read().get(&id) {
+            let top_levels = ctx.top_level.definitions.read();
+            if matches!(&*top_levels[def_id.0].read(), TopLevelDef::Variable { .. }) {
+                let module_val = &self.0.module;
+                let ret = Python::with_gil(|py| -> PyResult<Result<BasicValueEnum, String>> {
+                    let module_val = module_val.as_ref(py);
+
+                    let ty = self.0.get_obj_type(
+                        py,
+                        module_val,
+                        &mut ctx.unifier,
+                        &top_levels,
+                        &ctx.primitives,
+                    )?;
+                    if let Err(ty) = ty {
+                        return Ok(Err(ty));
+                    }
+                    let ty = ty.unwrap();
+                    let obj = self.0.get_obj_value(py, module_val, ctx, generator, ty)?.unwrap();
+                    let (idx, _) = ctx.get_attr_index(ty, id);
+                    let ret = unsafe {
+                        ctx.builder.build_gep(
+                            obj.into_pointer_value(),
+                            &[
+                                ctx.ctx.i32_type().const_zero(),
+                                ctx.ctx.i32_type().const_int(idx as u64, false),
+                            ],
+                            id.to_string().as_str(),
+                        )
+                    }
+                    .unwrap();
+                    Ok(Ok(ret.as_basic_value_enum()))
+                })
+                .unwrap();
+                if ret.is_err() {
+                    return None;
+                }
+                return Some(ret.unwrap().into());
+            }
+        }
+
         let sym_value = {
             let id_to_val = self.0.id_to_pyval.read();
             id_to_val.get(&id).cloned()
diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs
index a4ca27f1..a6a0ce76 100644
--- a/nac3core/src/toplevel/composer.rs
+++ b/nac3core/src/toplevel/composer.rs
@@ -101,8 +101,9 @@ impl TopLevelComposer {
         let builtin_name_list = definition_ast_list
             .iter()
             .map(|def_ast| match *def_ast.0.read() {
-                TopLevelDef::Class { name, .. } 
-                | TopLevelDef::Module { name, .. } => name.to_string(),
+                TopLevelDef::Class { name, .. } | TopLevelDef::Module { name, .. } => {
+                    name.to_string()
+                }
                 TopLevelDef::Function { simple_name, .. }
                 | TopLevelDef::Variable { simple_name, .. } => simple_name.to_string(),
             })
@@ -205,29 +206,37 @@ impl TopLevelComposer {
     /// register top level modules
     pub fn register_top_level_module(
         &mut self,
-        module_name: String,
-        name_to_pyid: Rc<HashMap<StrRef, u64>>,
+        module_name: &str,
+        name_to_pyid: &Rc<HashMap<StrRef, u64>>,
         resolver: Arc<dyn SymbolResolver + Send + Sync>,
-        location: Option<Location>
+        location: Option<Location>,
     ) -> Result<DefinitionId, String> {
-        let mut attributes: HashMap<StrRef, DefinitionId> = HashMap::new();
+        let mut methods: HashMap<StrRef, DefinitionId> = HashMap::new();
+        let mut attributes: Vec<(StrRef, DefinitionId)> = Vec::new();
+
         for (name, _) in name_to_pyid.iter() {
             if let Ok(def_id) = resolver.get_identifier_def(*name) {
                 // Avoid repeated attribute instances resulting from multiple imports of same module
                 if self.defined_names.contains(&format!("{module_name}.{name}")) {
-                    attributes.insert(*name, def_id);
+                    match &*self.definition_ast_list[def_id.0].0.read() {
+                        TopLevelDef::Class { .. } | TopLevelDef::Function { .. } => {
+                            methods.insert(*name, def_id);
+                        }
+                        _ => attributes.push((*name, def_id)),
+                    }
                 }
             };
         }
-        let module_def = TopLevelDef::Module { 
-            name: module_name.clone().into(), 
-            module_id: DefinitionId(self.definition_ast_list.len()), 
-            attributes, 
-            resolver: Some(resolver), 
-            loc: location
+        let module_def = TopLevelDef::Module {
+            name: module_name.to_string().into(),
+            module_id: DefinitionId(self.definition_ast_list.len()),
+            methods,
+            attributes,
+            resolver: Some(resolver),
+            loc: location,
         };
 
-        self.definition_ast_list.push((Arc::new(RwLock::new(module_def)).into(), None));
+        self.definition_ast_list.push((Arc::new(RwLock::new(module_def)), None));
         Ok(DefinitionId(self.definition_ast_list.len() - 1))
     }
 
@@ -499,10 +508,10 @@ impl TopLevelComposer {
         self.analyze_top_level_class_definition()?;
         self.analyze_top_level_class_fields_methods()?;
         self.analyze_top_level_function()?;
+        self.analyze_top_level_variables()?;
         if inference {
             self.analyze_function_instance()?;
         }
-        self.analyze_top_level_variables()?;
         Ok(())
     }
 
@@ -1440,7 +1449,7 @@ impl TopLevelComposer {
         Ok(())
     }
 
-    /// step 4, analyze and call type inferencer to fill the `instance_to_stmt` of
+    /// step 5, analyze and call type inferencer to fill the `instance_to_stmt` of
     /// [`TopLevelDef::Function`]
     fn analyze_function_instance(&mut self) -> Result<(), HashSet<String>> {
         // first get the class constructor type correct for the following type check in function body
@@ -1971,7 +1980,7 @@ impl TopLevelComposer {
         Ok(())
     }
 
-    /// Step 5. Analyze and populate the types of global variables.
+    /// Step 4. Analyze and populate the types of global variables.
     fn analyze_top_level_variables(&mut self) -> Result<(), HashSet<String>> {
         let def_list = &self.definition_ast_list;
         let temp_def_list = self.extract_def_list();
@@ -1989,6 +1998,19 @@ impl TopLevelComposer {
             let resolver = &**resolver.as_ref().unwrap();
 
             if let Some(ty_decl) = ty_decl {
+                let ty_decl = match &ty_decl.node {
+                    ExprKind::Subscript { value, slice, .. }
+                        if matches!(
+                            &value.node,
+                            ast::ExprKind::Name { id, .. } if self.core_config.kernel_ann.map_or(false, |c| id == &c.into())
+                        ) =>
+                    {
+                        slice
+                    }
+                    _ if self.core_config.kernel_ann.is_none() => ty_decl,
+                    _ => unreachable!("Global variables should be annotated with Kernel[]"), // ignore fields annotated otherwise
+                };
+
                 let ty_annotation = parse_ast_to_type_annotation_kinds(
                     resolver,
                     &temp_def_list,
diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs
index 72502aa4..4ca5464f 100644
--- a/nac3core/src/toplevel/helper.rs
+++ b/nac3core/src/toplevel/helper.rs
@@ -379,11 +379,12 @@ pub fn make_exception_fields(int32: Type, int64: Type, str: Type) -> Vec<(StrRef
 impl TopLevelDef {
     pub fn to_string(&self, unifier: &mut Unifier) -> String {
         match self {
-            TopLevelDef::Module { name, attributes, .. } => {
-                let method_str = attributes.iter().map(|(n, _)| n.to_string()).collect_vec();
+            TopLevelDef::Module { name, attributes, methods, .. } => {
                 format!(
-                    "Module {{\nname: {:?},\nattributes{:?}\n}}",
-                    name, method_str
+                    "Module {{\nname: {:?},\nattributes: {:?}\nmethods: {:?}\n}}",
+                    name,
+                    attributes.iter().map(|(n, _)| n.to_string()).collect_vec(),
+                    methods.iter().map(|(n, _)| n.to_string()).collect_vec()
                 )
             }
             TopLevelDef::Class {
diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs
index 88c007ec..3ffd568a 100644
--- a/nac3core/src/toplevel/mod.rs
+++ b/nac3core/src/toplevel/mod.rs
@@ -97,8 +97,10 @@ pub enum TopLevelDef {
         name: StrRef,
         /// Module ID used for [`TypeEnum`]
         module_id: DefinitionId,
-        /// DefinitionId of `TopLevelDef::{Class, Function, Variable}` within the module
-        attributes: HashMap<StrRef, DefinitionId>,
+        /// `DefinitionId` of `TopLevelDef::{Class, Function}` within the module
+        methods: HashMap<StrRef, DefinitionId>,
+        /// `DefinitionId` of `TopLevelDef::{Variable}` within the module
+        attributes: Vec<(StrRef, DefinitionId)>,
         /// Symbol resolver of the module defined the class.
         resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>,
         /// Definition location.