[artiq] add global variables to modules

This commit is contained in:
abdul124 2025-01-16 11:08:55 +08:00
parent ce40a46f8a
commit 32f24261f2
5 changed files with 234 additions and 38 deletions

View File

@ -277,6 +277,10 @@ impl Nac3 {
} }
}) })
} }
// Allow global variable declaration with `Kernel` type annotation
StmtKind::AnnAssign { ref annotation, .. } => {
matches!(&annotation.node, ExprKind::Subscript { value, .. } if matches!(&value.node, ExprKind::Name {id, ..} if id == &"Kernel".into()))
}
_ => false, _ => false,
}; };
@ -522,8 +526,15 @@ impl Nac3 {
as Arc<dyn SymbolResolver + Send + Sync>; as Arc<dyn SymbolResolver + Send + Sync>;
let name_to_pyid = Rc::new(name_to_pyid); let name_to_pyid = Rc::new(name_to_pyid);
let module_location = ast::Location::new(1, 1, stmt.location.file); let module_location = ast::Location::new(1, 1, stmt.location.file);
module_to_resolver_cache module_to_resolver_cache.insert(
.insert(module_id, (name_to_pyid.clone(), resolver.clone(), module_name.clone(), Some(module_location))); module_id,
(
name_to_pyid.clone(),
resolver.clone(),
module_name.clone(),
Some(module_location),
),
);
(name_to_pyid, resolver, module_name, Some(module_location)) (name_to_pyid, resolver, module_name, Some(module_location))
}); });
@ -599,15 +610,19 @@ impl Nac3 {
} }
// Adding top level module definitions // Adding top level module definitions
for (module_id, (module_name_to_pyid, module_resolver, module_name, module_location)) in module_to_resolver_cache.into_iter() { for (module_id, (module_name_to_pyid, module_resolver, module_name, module_location)) in
let def_id= composer.register_top_level_module( module_to_resolver_cache
module_name, {
module_name_to_pyid, let def_id = composer
module_resolver, .register_top_level_module(
module_location &module_name,
).map_err(|e| { &module_name_to_pyid,
CompileError::new_err(format!("compilation failed\n----------\n{e}")) module_resolver,
})?; module_location,
)
.map_err(|e| {
CompileError::new_err(format!("compilation failed\n----------\n{e}"))
})?;
self.pyid_to_def.write().insert(module_id, def_id); self.pyid_to_def.write().insert(module_id, def_id);
} }
@ -731,7 +746,9 @@ impl Nac3 {
"Unsupported @rpc annotation on global variable", "Unsupported @rpc annotation on global variable",
))) )))
} }
TopLevelDef::Module { .. } => unreachable!("Type module cannot be decorated with @rpc"), TopLevelDef::Module { .. } => {
unreachable!("Type module cannot be decorated with @rpc")
}
} }
} }
} }

View File

@ -23,7 +23,7 @@ use nac3core::{
inkwell::{ inkwell::{
module::Linkage, module::Linkage,
types::{BasicType, BasicTypeEnum}, types::{BasicType, BasicTypeEnum},
values::BasicValueEnum, values::{BasicValue, BasicValueEnum},
AddressSpace, AddressSpace,
}, },
nac3parser::ast::{self, StrRef}, nac3parser::ast::{self, StrRef},
@ -674,6 +674,48 @@ impl InnerResolver {
}) })
}); });
// check if obj is module
if self.helper.id_fn.call1(py, (ty.clone(),))?.extract::<u64>(py)?
== self.primitive_ids.module
&& self.pyid_to_def.read().contains_key(&py_obj_id)
{
let def_id = self.pyid_to_def.read()[&py_obj_id];
let def = defs[def_id.0].read();
let TopLevelDef::Module { name: module_name, module_id, attributes, methods, .. } =
&*def
else {
unreachable!("must be a module here");
};
// Construct the module return type
let mut module_attributes = HashMap::new();
for (name, _) in attributes {
let attribute_obj = obj.getattr(name.to_string().as_str())?;
let attribute_ty =
self.get_obj_type(py, attribute_obj, unifier, defs, primitives)?;
if let Ok(attribute_ty) = attribute_ty {
module_attributes.insert(*name, (attribute_ty, false));
} else {
return Ok(Err(format!("Unable to resolve {module_name}.{name}")));
}
}
for name in methods.keys() {
let method_obj = obj.getattr(name.to_string().as_str())?;
let method_ty = self.get_obj_type(py, method_obj, unifier, defs, primitives)?;
if let Ok(method_ty) = method_ty {
module_attributes.insert(*name, (method_ty, true));
} else {
return Ok(Err(format!("Unable to resolve {module_name}.{name}")));
}
}
let module_ty =
TypeEnum::TModule { module_id: *module_id, attributes: module_attributes };
let ty = unifier.add_ty(module_ty);
return Ok(Ok(ty));
}
if let Some(ty) = constructor_ty { if let Some(ty) = constructor_ty {
self.pyid_to_type.write().insert(py_obj_id, ty); self.pyid_to_type.write().insert(py_obj_id, ty);
return Ok(Ok(ty)); return Ok(Ok(ty));
@ -1373,6 +1415,77 @@ impl InnerResolver {
None => Ok(None), None => Ok(None),
} }
} }
} else if ty_id == self.primitive_ids.module {
let id_str = id.to_string();
if let Some(global) = ctx.module.get_global(&id_str) {
return Ok(Some(global.as_pointer_value().into()));
}
let top_level_defs = ctx.top_level.definitions.read();
let ty = self
.get_obj_type(py, obj, &mut ctx.unifier, &top_level_defs, &ctx.primitives)?
.unwrap();
let ty = ctx
.get_llvm_type(generator, ty)
.into_pointer_type()
.get_element_type()
.into_struct_type();
{
if self.global_value_ids.read().contains_key(&id) {
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
ctx.module.add_global(ty, Some(AddressSpace::default()), &id_str)
});
return Ok(Some(global.as_pointer_value().into()));
}
self.global_value_ids.write().insert(id, obj.into());
}
let fields = {
let definition =
top_level_defs.get(self.pyid_to_def.read().get(&id).unwrap().0).unwrap().read();
let TopLevelDef::Module { attributes, .. } = &*definition else { unreachable!() };
attributes
.iter()
.filter_map(|f| {
let definition = top_level_defs.get(f.1 .0).unwrap().read();
if let TopLevelDef::Variable { ty, .. } = &*definition {
Some((f.0, *ty))
} else {
None
}
})
.collect_vec()
};
let values: Result<Option<Vec<_>>, _> = fields
.iter()
.map(|(name, ty)| {
self.get_obj_value(
py,
obj.getattr(name.to_string().as_str())?,
ctx,
generator,
*ty,
)
.map_err(|e| {
super::CompileError::new_err(format!("Error getting field {name}: {e}"))
})
})
.collect();
let values = values?;
if let Some(values) = values {
let val = ty.const_named_struct(&values);
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
ctx.module.add_global(ty, Some(AddressSpace::default()), &id_str)
});
global.set_initializer(&val);
Ok(Some(global.as_pointer_value().into()))
} else {
Ok(None)
}
} else { } else {
let id_str = id.to_string(); let id_str = id.to_string();
@ -1555,9 +1668,50 @@ impl SymbolResolver for Resolver {
fn get_symbol_value<'ctx>( fn get_symbol_value<'ctx>(
&self, &self,
id: StrRef, id: StrRef,
_: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
_: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
) -> Option<ValueEnum<'ctx>> { ) -> Option<ValueEnum<'ctx>> {
if let Some(def_id) = self.0.id_to_def.read().get(&id) {
let top_levels = ctx.top_level.definitions.read();
if matches!(&*top_levels[def_id.0].read(), TopLevelDef::Variable { .. }) {
let module_val = &self.0.module;
let ret = Python::with_gil(|py| -> PyResult<Result<BasicValueEnum, String>> {
let module_val = module_val.as_ref(py);
let ty = self.0.get_obj_type(
py,
module_val,
&mut ctx.unifier,
&top_levels,
&ctx.primitives,
)?;
if let Err(ty) = ty {
return Ok(Err(ty));
}
let ty = ty.unwrap();
let obj = self.0.get_obj_value(py, module_val, ctx, generator, ty)?.unwrap();
let (idx, _) = ctx.get_attr_index(ty, id);
let ret = unsafe {
ctx.builder.build_gep(
obj.into_pointer_value(),
&[
ctx.ctx.i32_type().const_zero(),
ctx.ctx.i32_type().const_int(idx as u64, false),
],
id.to_string().as_str(),
)
}
.unwrap();
Ok(Ok(ret.as_basic_value_enum()))
})
.unwrap();
if ret.is_err() {
return None;
}
return Some(ret.unwrap().into());
}
}
let sym_value = { let sym_value = {
let id_to_val = self.0.id_to_pyval.read(); let id_to_val = self.0.id_to_pyval.read();
id_to_val.get(&id).cloned() id_to_val.get(&id).cloned()

View File

@ -101,8 +101,9 @@ impl TopLevelComposer {
let builtin_name_list = definition_ast_list let builtin_name_list = definition_ast_list
.iter() .iter()
.map(|def_ast| match *def_ast.0.read() { .map(|def_ast| match *def_ast.0.read() {
TopLevelDef::Class { name, .. } TopLevelDef::Class { name, .. } | TopLevelDef::Module { name, .. } => {
| TopLevelDef::Module { name, .. } => name.to_string(), name.to_string()
}
TopLevelDef::Function { simple_name, .. } TopLevelDef::Function { simple_name, .. }
| TopLevelDef::Variable { simple_name, .. } => simple_name.to_string(), | TopLevelDef::Variable { simple_name, .. } => simple_name.to_string(),
}) })
@ -205,29 +206,37 @@ impl TopLevelComposer {
/// register top level modules /// register top level modules
pub fn register_top_level_module( pub fn register_top_level_module(
&mut self, &mut self,
module_name: String, module_name: &str,
name_to_pyid: Rc<HashMap<StrRef, u64>>, name_to_pyid: &Rc<HashMap<StrRef, u64>>,
resolver: Arc<dyn SymbolResolver + Send + Sync>, resolver: Arc<dyn SymbolResolver + Send + Sync>,
location: Option<Location> location: Option<Location>,
) -> Result<DefinitionId, String> { ) -> Result<DefinitionId, String> {
let mut attributes: HashMap<StrRef, DefinitionId> = HashMap::new(); let mut methods: HashMap<StrRef, DefinitionId> = HashMap::new();
let mut attributes: Vec<(StrRef, DefinitionId)> = Vec::new();
for (name, _) in name_to_pyid.iter() { for (name, _) in name_to_pyid.iter() {
if let Ok(def_id) = resolver.get_identifier_def(*name) { if let Ok(def_id) = resolver.get_identifier_def(*name) {
// Avoid repeated attribute instances resulting from multiple imports of same module // Avoid repeated attribute instances resulting from multiple imports of same module
if self.defined_names.contains(&format!("{module_name}.{name}")) { if self.defined_names.contains(&format!("{module_name}.{name}")) {
attributes.insert(*name, def_id); match &*self.definition_ast_list[def_id.0].0.read() {
TopLevelDef::Class { .. } | TopLevelDef::Function { .. } => {
methods.insert(*name, def_id);
}
_ => attributes.push((*name, def_id)),
}
} }
}; };
} }
let module_def = TopLevelDef::Module { let module_def = TopLevelDef::Module {
name: module_name.clone().into(), name: module_name.to_string().into(),
module_id: DefinitionId(self.definition_ast_list.len()), module_id: DefinitionId(self.definition_ast_list.len()),
methods,
attributes, attributes,
resolver: Some(resolver), resolver: Some(resolver),
loc: location loc: location,
}; };
self.definition_ast_list.push((Arc::new(RwLock::new(module_def)).into(), None)); self.definition_ast_list.push((Arc::new(RwLock::new(module_def)), None));
Ok(DefinitionId(self.definition_ast_list.len() - 1)) Ok(DefinitionId(self.definition_ast_list.len() - 1))
} }
@ -499,10 +508,10 @@ impl TopLevelComposer {
self.analyze_top_level_class_definition()?; self.analyze_top_level_class_definition()?;
self.analyze_top_level_class_fields_methods()?; self.analyze_top_level_class_fields_methods()?;
self.analyze_top_level_function()?; self.analyze_top_level_function()?;
self.analyze_top_level_variables()?;
if inference { if inference {
self.analyze_function_instance()?; self.analyze_function_instance()?;
} }
self.analyze_top_level_variables()?;
Ok(()) Ok(())
} }
@ -1440,7 +1449,7 @@ impl TopLevelComposer {
Ok(()) Ok(())
} }
/// step 4, analyze and call type inferencer to fill the `instance_to_stmt` of /// step 5, analyze and call type inferencer to fill the `instance_to_stmt` of
/// [`TopLevelDef::Function`] /// [`TopLevelDef::Function`]
fn analyze_function_instance(&mut self) -> Result<(), HashSet<String>> { fn analyze_function_instance(&mut self) -> Result<(), HashSet<String>> {
// first get the class constructor type correct for the following type check in function body // first get the class constructor type correct for the following type check in function body
@ -1971,7 +1980,7 @@ impl TopLevelComposer {
Ok(()) Ok(())
} }
/// Step 5. Analyze and populate the types of global variables. /// Step 4. Analyze and populate the types of global variables.
fn analyze_top_level_variables(&mut self) -> Result<(), HashSet<String>> { fn analyze_top_level_variables(&mut self) -> Result<(), HashSet<String>> {
let def_list = &self.definition_ast_list; let def_list = &self.definition_ast_list;
let temp_def_list = self.extract_def_list(); let temp_def_list = self.extract_def_list();
@ -1989,6 +1998,19 @@ impl TopLevelComposer {
let resolver = &**resolver.as_ref().unwrap(); let resolver = &**resolver.as_ref().unwrap();
if let Some(ty_decl) = ty_decl { if let Some(ty_decl) = ty_decl {
let ty_decl = match &ty_decl.node {
ExprKind::Subscript { value, slice, .. }
if matches!(
&value.node,
ast::ExprKind::Name { id, .. } if self.core_config.kernel_ann.map_or(false, |c| id == &c.into())
) =>
{
slice
}
_ if self.core_config.kernel_ann.is_none() => ty_decl,
_ => unreachable!("Global variables should be annotated with Kernel[]"), // ignore fields annotated otherwise
};
let ty_annotation = parse_ast_to_type_annotation_kinds( let ty_annotation = parse_ast_to_type_annotation_kinds(
resolver, resolver,
&temp_def_list, &temp_def_list,

View File

@ -379,11 +379,12 @@ pub fn make_exception_fields(int32: Type, int64: Type, str: Type) -> Vec<(StrRef
impl TopLevelDef { impl TopLevelDef {
pub fn to_string(&self, unifier: &mut Unifier) -> String { pub fn to_string(&self, unifier: &mut Unifier) -> String {
match self { match self {
TopLevelDef::Module { name, attributes, .. } => { TopLevelDef::Module { name, attributes, methods, .. } => {
let method_str = attributes.iter().map(|(n, _)| n.to_string()).collect_vec();
format!( format!(
"Module {{\nname: {:?},\nattributes{:?}\n}}", "Module {{\nname: {:?},\nattributes: {:?}\nmethods: {:?}\n}}",
name, method_str name,
attributes.iter().map(|(n, _)| n.to_string()).collect_vec(),
methods.iter().map(|(n, _)| n.to_string()).collect_vec()
) )
} }
TopLevelDef::Class { TopLevelDef::Class {

View File

@ -97,8 +97,10 @@ pub enum TopLevelDef {
name: StrRef, name: StrRef,
/// Module ID used for [`TypeEnum`] /// Module ID used for [`TypeEnum`]
module_id: DefinitionId, module_id: DefinitionId,
/// DefinitionId of `TopLevelDef::{Class, Function, Variable}` within the module /// `DefinitionId` of `TopLevelDef::{Class, Function}` within the module
attributes: HashMap<StrRef, DefinitionId>, methods: HashMap<StrRef, DefinitionId>,
/// `DefinitionId` of `TopLevelDef::{Variable}` within the module
attributes: Vec<(StrRef, DefinitionId)>,
/// Symbol resolver of the module defined the class. /// Symbol resolver of the module defined the class.
resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>, resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>,
/// Definition location. /// Definition location.