From 5fdbc34b430bd5875623eb5a0e0839d99422555d Mon Sep 17 00:00:00 2001 From: abdul124 Date: Thu, 16 Jan 2025 11:11:53 +0800 Subject: [PATCH] [core] implement codegen for modules --- nac3artiq/src/codegen.rs | 28 +++++++++++++++++++++++ nac3core/src/codegen/concrete_type.rs | 13 +++++++++++ nac3core/src/codegen/expr.rs | 26 +++++++++++++++++----- nac3core/src/codegen/mod.rs | 32 +++++++++++++++++++++++++++ nac3core/src/symbol_resolver.rs | 8 ++++--- 5 files changed, 98 insertions(+), 9 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index c968198b..e4727056 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -1052,6 +1052,34 @@ pub fn attributes_writeback<'ctx>( )); } } + TypeEnum::TModule { attributes, .. } => { + let mut fields = Vec::new(); + let obj = inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap(); + + for (name, (field_ty, is_method)) in attributes { + if *is_method { + continue; + } + if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() { + fields.push(name.to_string()); + let (index, _) = ctx.get_attr_index(ty, *name); + values.push(( + *field_ty, + ctx.build_gep_and_load( + obj.into_pointer_value(), + &[zero, int32.const_int(index as u64, false)], + None, + ), + )); + } + } + if !fields.is_empty() { + let pydict = PyDict::new(py); + pydict.set_item("obj", val)?; + pydict.set_item("fields", fields)?; + host_attributes.append(pydict)?; + } + } _ => {} } } diff --git a/nac3core/src/codegen/concrete_type.rs b/nac3core/src/codegen/concrete_type.rs index f0c92ed8..503a4ae8 100644 --- a/nac3core/src/codegen/concrete_type.rs +++ b/nac3core/src/codegen/concrete_type.rs @@ -56,6 +56,10 @@ pub enum ConcreteTypeEnum { fields: HashMap, params: IndexMap, }, + TModule { + module_id: DefinitionId, + methods: HashMap, + }, TVirtual { ty: ConcreteType, }, @@ -297,6 +301,15 @@ impl ConcreteTypeStore { TypeVar { id, ty } })), }, + ConcreteTypeEnum::TModule { module_id, methods } => TypeEnum::TModule { + module_id: *module_id, + attributes: methods + .iter() + .map(|(name, cty)| { + (*name, (self.to_unifier_type(unifier, primitives, cty.0, cache), cty.1)) + }) + .collect::>(), + }, ConcreteTypeEnum::TFunc { args, ret, vars } => TypeEnum::TFunc(FunSignature { args: args .iter() diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 6d2057e1..fd2cd286 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -61,8 +61,13 @@ pub fn get_subst_key( ) -> String { let mut vars = obj .map(|ty| { - let TypeEnum::TObj { params, .. } = &*unifier.get_ty(ty) else { unreachable!() }; - params.clone() + if let TypeEnum::TObj { params, .. } = &*unifier.get_ty(ty) { + params.clone() + } else if let TypeEnum::TModule { .. } = &*unifier.get_ty(ty) { + indexmap::IndexMap::new() + } else { + unreachable!() + } }) .unwrap_or_default(); vars.extend(fun_vars); @@ -120,6 +125,7 @@ impl<'ctx> CodeGenContext<'ctx, '_> { pub fn get_attr_index(&mut self, ty: Type, attr: StrRef) -> (usize, Option) { let obj_id = match &*self.unifier.get_ty(ty) { TypeEnum::TObj { obj_id, .. } => *obj_id, + TypeEnum::TModule { module_id, .. } => *module_id, // we cannot have other types, virtual type should be handled by function calls _ => codegen_unreachable!(self), }; @@ -131,6 +137,8 @@ impl<'ctx> CodeGenContext<'ctx, '_> { let attribute_index = attributes.iter().find_position(|x| x.0 == attr).unwrap(); (attribute_index.0, Some(attribute_index.1 .2.clone())) } + } else if let TopLevelDef::Module { attributes, .. } = &*def.read() { + (attributes.iter().find_position(|x| x.0 == attr).unwrap().0, None) } else { codegen_unreachable!(self) }; @@ -2805,6 +2813,10 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( &*ctx.unifier.get_ty(value.custom.unwrap()) { *obj_id + } else if let TypeEnum::TModule { module_id, .. } = + &*ctx.unifier.get_ty(value.custom.unwrap()) + { + *module_id } else { codegen_unreachable!(ctx) }; @@ -2815,11 +2827,13 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( } else { let defs = ctx.top_level.definitions.read(); let obj_def = defs.get(id.0).unwrap().read(); - let TopLevelDef::Class { methods, .. } = &*obj_def else { + if let TopLevelDef::Class { methods, .. } = &*obj_def { + methods.iter().find(|method| method.0 == *attr).unwrap().2 + } else if let TopLevelDef::Module { methods, .. } = &*obj_def { + *methods.iter().find(|method| method.0 == attr).unwrap().1 + } else { codegen_unreachable!(ctx) - }; - - methods.iter().find(|method| method.0 == *attr).unwrap().2 + } }; // directly generate code for option.unwrap // since it needs to return static value to optimize for kernel invariant diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index dcfa2b8c..37e1bb33 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -501,6 +501,38 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( type_cache.get(&unifier.get_representative(ty)).copied().unwrap_or_else(|| { let ty_enum = unifier.get_ty(ty); let result = match &*ty_enum { + TModule {module_id, attributes} => { + let top_level_defs = top_level.definitions.read(); + let definition = top_level_defs.get(module_id.0).unwrap(); + let TopLevelDef::Module { name, attributes: attribute_fields, .. } = &*definition.read() else { + unreachable!() + }; + let ty: BasicTypeEnum<'_> = if let Some(t) = module.get_struct_type(&name.to_string()) { + t.ptr_type(AddressSpace::default()).into() + } else { + let struct_type = ctx.opaque_struct_type(&name.to_string()); + type_cache.insert( + unifier.get_representative(ty), + struct_type.ptr_type(AddressSpace::default()).into(), + ); + let module_fields: Vec> = attribute_fields.iter() + .map(|f| { + get_llvm_type( + ctx, + module, + generator, + unifier, + top_level, + type_cache, + attributes[&f.0].0, + ) + }) + .collect_vec(); + struct_type.set_body(&module_fields, false); + struct_type.ptr_type(AddressSpace::default()).into() + }; + return ty; + }, TObj { obj_id, fields, .. } => { // check to avoid treating non-class primitives as classes if PrimDef::contains_id(*obj_id) { diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index 2378dd62..48290935 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -598,10 +598,12 @@ impl dyn SymbolResolver + Send + Sync { unifier.internal_stringify( ty, &mut |id| { - let TopLevelDef::Class { name, .. } = &*top_level_defs[id].read() else { - unreachable!("expected class definition") + let top_level_def = &*top_level_defs[id].read(); + let (TopLevelDef::Class { name, .. } | TopLevelDef::Module { name, .. }) = + top_level_def + else { + unreachable!("expected class/module definition") }; - name.to_string() }, &mut |id| format!("typevar{id}"),