diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 78f427e..ba6c4fa 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -277,6 +277,10 @@ impl Nac3 { } }) } + // Allow global variable declaration with `Kernel` type annotation + StmtKind::AnnAssign { ref annotation, .. } => { + matches!(&annotation.node, ExprKind::Subscript { value, .. } if matches!(&value.node, ExprKind::Name {id, ..} if id == &"Kernel".into())) + } _ => false, }; @@ -522,8 +526,15 @@ impl Nac3 { as Arc; let name_to_pyid = Rc::new(name_to_pyid); let module_location = ast::Location::new(1, 1, stmt.location.file); - module_to_resolver_cache - .insert(module_id, (name_to_pyid.clone(), resolver.clone(), module_name.clone(), Some(module_location))); + module_to_resolver_cache.insert( + module_id, + ( + name_to_pyid.clone(), + resolver.clone(), + module_name.clone(), + Some(module_location), + ), + ); (name_to_pyid, resolver, module_name, Some(module_location)) }); @@ -599,15 +610,19 @@ impl Nac3 { } // Adding top level module definitions - for (module_id, (module_name_to_pyid, module_resolver, module_name, module_location)) in module_to_resolver_cache.into_iter() { - let def_id= composer.register_top_level_module( - module_name, - module_name_to_pyid, - module_resolver, - module_location - ).map_err(|e| { - CompileError::new_err(format!("compilation failed\n----------\n{e}")) - })?; + for (module_id, (module_name_to_pyid, module_resolver, module_name, module_location)) in + module_to_resolver_cache + { + let def_id = composer + .register_top_level_module( + &module_name, + &module_name_to_pyid, + module_resolver, + module_location, + ) + .map_err(|e| { + CompileError::new_err(format!("compilation failed\n----------\n{e}")) + })?; self.pyid_to_def.write().insert(module_id, def_id); } @@ -731,7 +746,9 @@ impl Nac3 { "Unsupported @rpc annotation on global variable", ))) } - TopLevelDef::Module { .. } => unreachable!("Type module cannot be decorated with @rpc"), + TopLevelDef::Module { .. } => { + unreachable!("Type module cannot be decorated with @rpc") + } } } } diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index d976866..4b398a9 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -23,7 +23,7 @@ use nac3core::{ inkwell::{ module::Linkage, types::{BasicType, BasicTypeEnum}, - values::BasicValueEnum, + values::{BasicValue, BasicValueEnum}, AddressSpace, }, nac3parser::ast::{self, StrRef}, @@ -674,6 +674,48 @@ impl InnerResolver { }) }); + // check if obj is module + if self.helper.id_fn.call1(py, (ty.clone(),))?.extract::(py)? + == self.primitive_ids.module + && self.pyid_to_def.read().contains_key(&py_obj_id) + { + let def_id = self.pyid_to_def.read()[&py_obj_id]; + let def = defs[def_id.0].read(); + let TopLevelDef::Module { name: module_name, module_id, attributes, methods, .. } = + &*def + else { + unreachable!("must be a module here"); + }; + // Construct the module return type + let mut module_attributes = HashMap::new(); + for (name, _) in attributes { + let attribute_obj = obj.getattr(name.to_string().as_str())?; + let attribute_ty = + self.get_obj_type(py, attribute_obj, unifier, defs, primitives)?; + if let Ok(attribute_ty) = attribute_ty { + module_attributes.insert(*name, (attribute_ty, false)); + } else { + return Ok(Err(format!("Unable to resolve {module_name}.{name}"))); + } + } + + for name in methods.keys() { + let method_obj = obj.getattr(name.to_string().as_str())?; + let method_ty = self.get_obj_type(py, method_obj, unifier, defs, primitives)?; + if let Ok(method_ty) = method_ty { + module_attributes.insert(*name, (method_ty, true)); + } else { + return Ok(Err(format!("Unable to resolve {module_name}.{name}"))); + } + } + + let module_ty = + TypeEnum::TModule { module_id: *module_id, attributes: module_attributes }; + + let ty = unifier.add_ty(module_ty); + return Ok(Ok(ty)); + } + if let Some(ty) = constructor_ty { self.pyid_to_type.write().insert(py_obj_id, ty); return Ok(Ok(ty)); @@ -1373,6 +1415,77 @@ impl InnerResolver { None => Ok(None), } } + } else if ty_id == self.primitive_ids.module { + let id_str = id.to_string(); + + if let Some(global) = ctx.module.get_global(&id_str) { + return Ok(Some(global.as_pointer_value().into())); + } + + let top_level_defs = ctx.top_level.definitions.read(); + let ty = self + .get_obj_type(py, obj, &mut ctx.unifier, &top_level_defs, &ctx.primitives)? + .unwrap(); + let ty = ctx + .get_llvm_type(generator, ty) + .into_pointer_type() + .get_element_type() + .into_struct_type(); + + { + if self.global_value_ids.read().contains_key(&id) { + let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { + ctx.module.add_global(ty, Some(AddressSpace::default()), &id_str) + }); + return Ok(Some(global.as_pointer_value().into())); + } + self.global_value_ids.write().insert(id, obj.into()); + } + + let fields = { + let definition = + top_level_defs.get(self.pyid_to_def.read().get(&id).unwrap().0).unwrap().read(); + let TopLevelDef::Module { attributes, .. } = &*definition else { unreachable!() }; + attributes + .iter() + .filter_map(|f| { + let definition = top_level_defs.get(f.1 .0).unwrap().read(); + if let TopLevelDef::Variable { ty, .. } = &*definition { + Some((f.0, *ty)) + } else { + None + } + }) + .collect_vec() + }; + + let values: Result>, _> = fields + .iter() + .map(|(name, ty)| { + self.get_obj_value( + py, + obj.getattr(name.to_string().as_str())?, + ctx, + generator, + *ty, + ) + .map_err(|e| { + super::CompileError::new_err(format!("Error getting field {name}: {e}")) + }) + }) + .collect(); + let values = values?; + + if let Some(values) = values { + let val = ty.const_named_struct(&values); + let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { + ctx.module.add_global(ty, Some(AddressSpace::default()), &id_str) + }); + global.set_initializer(&val); + Ok(Some(global.as_pointer_value().into())) + } else { + Ok(None) + } } else { let id_str = id.to_string(); @@ -1555,9 +1668,50 @@ impl SymbolResolver for Resolver { fn get_symbol_value<'ctx>( &self, id: StrRef, - _: &mut CodeGenContext<'ctx, '_>, - _: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut dyn CodeGenerator, ) -> Option> { + if let Some(def_id) = self.0.id_to_def.read().get(&id) { + let top_levels = ctx.top_level.definitions.read(); + if matches!(&*top_levels[def_id.0].read(), TopLevelDef::Variable { .. }) { + let module_val = &self.0.module; + let ret = Python::with_gil(|py| -> PyResult> { + let module_val = module_val.as_ref(py); + + let ty = self.0.get_obj_type( + py, + module_val, + &mut ctx.unifier, + &top_levels, + &ctx.primitives, + )?; + if let Err(ty) = ty { + return Ok(Err(ty)); + } + let ty = ty.unwrap(); + let obj = self.0.get_obj_value(py, module_val, ctx, generator, ty)?.unwrap(); + let (idx, _) = ctx.get_attr_index(ty, id); + let ret = unsafe { + ctx.builder.build_gep( + obj.into_pointer_value(), + &[ + ctx.ctx.i32_type().const_zero(), + ctx.ctx.i32_type().const_int(idx as u64, false), + ], + id.to_string().as_str(), + ) + } + .unwrap(); + Ok(Ok(ret.as_basic_value_enum())) + }) + .unwrap(); + if ret.is_err() { + return None; + } + return Some(ret.unwrap().into()); + } + } + let sym_value = { let id_to_val = self.0.id_to_pyval.read(); id_to_val.get(&id).cloned() diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index a4ca27f..a6a0ce7 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -101,8 +101,9 @@ impl TopLevelComposer { let builtin_name_list = definition_ast_list .iter() .map(|def_ast| match *def_ast.0.read() { - TopLevelDef::Class { name, .. } - | TopLevelDef::Module { name, .. } => name.to_string(), + TopLevelDef::Class { name, .. } | TopLevelDef::Module { name, .. } => { + name.to_string() + } TopLevelDef::Function { simple_name, .. } | TopLevelDef::Variable { simple_name, .. } => simple_name.to_string(), }) @@ -205,29 +206,37 @@ impl TopLevelComposer { /// register top level modules pub fn register_top_level_module( &mut self, - module_name: String, - name_to_pyid: Rc>, + module_name: &str, + name_to_pyid: &Rc>, resolver: Arc, - location: Option + location: Option, ) -> Result { - let mut attributes: HashMap = HashMap::new(); + let mut methods: HashMap = HashMap::new(); + let mut attributes: Vec<(StrRef, DefinitionId)> = Vec::new(); + for (name, _) in name_to_pyid.iter() { if let Ok(def_id) = resolver.get_identifier_def(*name) { // Avoid repeated attribute instances resulting from multiple imports of same module if self.defined_names.contains(&format!("{module_name}.{name}")) { - attributes.insert(*name, def_id); + match &*self.definition_ast_list[def_id.0].0.read() { + TopLevelDef::Class { .. } | TopLevelDef::Function { .. } => { + methods.insert(*name, def_id); + } + _ => attributes.push((*name, def_id)), + } } }; } - let module_def = TopLevelDef::Module { - name: module_name.clone().into(), - module_id: DefinitionId(self.definition_ast_list.len()), - attributes, - resolver: Some(resolver), - loc: location + let module_def = TopLevelDef::Module { + name: module_name.to_string().into(), + module_id: DefinitionId(self.definition_ast_list.len()), + methods, + attributes, + resolver: Some(resolver), + loc: location, }; - self.definition_ast_list.push((Arc::new(RwLock::new(module_def)).into(), None)); + self.definition_ast_list.push((Arc::new(RwLock::new(module_def)), None)); Ok(DefinitionId(self.definition_ast_list.len() - 1)) } @@ -499,10 +508,10 @@ impl TopLevelComposer { self.analyze_top_level_class_definition()?; self.analyze_top_level_class_fields_methods()?; self.analyze_top_level_function()?; + self.analyze_top_level_variables()?; if inference { self.analyze_function_instance()?; } - self.analyze_top_level_variables()?; Ok(()) } @@ -1440,7 +1449,7 @@ impl TopLevelComposer { Ok(()) } - /// step 4, analyze and call type inferencer to fill the `instance_to_stmt` of + /// step 5, analyze and call type inferencer to fill the `instance_to_stmt` of /// [`TopLevelDef::Function`] fn analyze_function_instance(&mut self) -> Result<(), HashSet> { // first get the class constructor type correct for the following type check in function body @@ -1971,7 +1980,7 @@ impl TopLevelComposer { Ok(()) } - /// Step 5. Analyze and populate the types of global variables. + /// Step 4. Analyze and populate the types of global variables. fn analyze_top_level_variables(&mut self) -> Result<(), HashSet> { let def_list = &self.definition_ast_list; let temp_def_list = self.extract_def_list(); @@ -1989,6 +1998,19 @@ impl TopLevelComposer { let resolver = &**resolver.as_ref().unwrap(); if let Some(ty_decl) = ty_decl { + let ty_decl = match &ty_decl.node { + ExprKind::Subscript { value, slice, .. } + if matches!( + &value.node, + ast::ExprKind::Name { id, .. } if self.core_config.kernel_ann.map_or(false, |c| id == &c.into()) + ) => + { + slice + } + _ if self.core_config.kernel_ann.is_none() => ty_decl, + _ => unreachable!("Global variables should be annotated with Kernel[]"), // ignore fields annotated otherwise + }; + let ty_annotation = parse_ast_to_type_annotation_kinds( resolver, &temp_def_list, diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 72502aa..4ca5464 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -379,11 +379,12 @@ pub fn make_exception_fields(int32: Type, int64: Type, str: Type) -> Vec<(StrRef impl TopLevelDef { pub fn to_string(&self, unifier: &mut Unifier) -> String { match self { - TopLevelDef::Module { name, attributes, .. } => { - let method_str = attributes.iter().map(|(n, _)| n.to_string()).collect_vec(); + TopLevelDef::Module { name, attributes, methods, .. } => { format!( - "Module {{\nname: {:?},\nattributes{:?}\n}}", - name, method_str + "Module {{\nname: {:?},\nattributes: {:?}\nmethods: {:?}\n}}", + name, + attributes.iter().map(|(n, _)| n.to_string()).collect_vec(), + methods.iter().map(|(n, _)| n.to_string()).collect_vec() ) } TopLevelDef::Class { diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index 88c007e..3ffd568 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -97,8 +97,10 @@ pub enum TopLevelDef { name: StrRef, /// Module ID used for [`TypeEnum`] module_id: DefinitionId, - /// DefinitionId of `TopLevelDef::{Class, Function, Variable}` within the module - attributes: HashMap, + /// `DefinitionId` of `TopLevelDef::{Class, Function}` within the module + methods: HashMap, + /// `DefinitionId` of `TopLevelDef::{Variable}` within the module + attributes: Vec<(StrRef, DefinitionId)>, /// Symbol resolver of the module defined the class. resolver: Option>, /// Definition location.