1
0
forked from M-Labs/nac3

[core] implement codegen for modules

This commit is contained in:
abdul124 2025-01-16 11:11:53 +08:00
parent 32f24261f2
commit 5fdbc34b43
5 changed files with 98 additions and 9 deletions

View File

@ -1052,6 +1052,34 @@ pub fn attributes_writeback<'ctx>(
)); ));
} }
} }
TypeEnum::TModule { attributes, .. } => {
let mut fields = Vec::new();
let obj = inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap();
for (name, (field_ty, is_method)) in attributes {
if *is_method {
continue;
}
if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() {
fields.push(name.to_string());
let (index, _) = ctx.get_attr_index(ty, *name);
values.push((
*field_ty,
ctx.build_gep_and_load(
obj.into_pointer_value(),
&[zero, int32.const_int(index as u64, false)],
None,
),
));
}
}
if !fields.is_empty() {
let pydict = PyDict::new(py);
pydict.set_item("obj", val)?;
pydict.set_item("fields", fields)?;
host_attributes.append(pydict)?;
}
}
_ => {} _ => {}
} }
} }

View File

@ -56,6 +56,10 @@ pub enum ConcreteTypeEnum {
fields: HashMap<StrRef, (ConcreteType, bool)>, fields: HashMap<StrRef, (ConcreteType, bool)>,
params: IndexMap<TypeVarId, ConcreteType>, params: IndexMap<TypeVarId, ConcreteType>,
}, },
TModule {
module_id: DefinitionId,
methods: HashMap<StrRef, (ConcreteType, bool)>,
},
TVirtual { TVirtual {
ty: ConcreteType, ty: ConcreteType,
}, },
@ -297,6 +301,15 @@ impl ConcreteTypeStore {
TypeVar { id, ty } TypeVar { id, ty }
})), })),
}, },
ConcreteTypeEnum::TModule { module_id, methods } => TypeEnum::TModule {
module_id: *module_id,
attributes: methods
.iter()
.map(|(name, cty)| {
(*name, (self.to_unifier_type(unifier, primitives, cty.0, cache), cty.1))
})
.collect::<HashMap<_, _>>(),
},
ConcreteTypeEnum::TFunc { args, ret, vars } => TypeEnum::TFunc(FunSignature { ConcreteTypeEnum::TFunc { args, ret, vars } => TypeEnum::TFunc(FunSignature {
args: args args: args
.iter() .iter()

View File

@ -61,8 +61,13 @@ pub fn get_subst_key(
) -> String { ) -> String {
let mut vars = obj let mut vars = obj
.map(|ty| { .map(|ty| {
let TypeEnum::TObj { params, .. } = &*unifier.get_ty(ty) else { unreachable!() }; if let TypeEnum::TObj { params, .. } = &*unifier.get_ty(ty) {
params.clone() params.clone()
} else if let TypeEnum::TModule { .. } = &*unifier.get_ty(ty) {
indexmap::IndexMap::new()
} else {
unreachable!()
}
}) })
.unwrap_or_default(); .unwrap_or_default();
vars.extend(fun_vars); vars.extend(fun_vars);
@ -120,6 +125,7 @@ impl<'ctx> CodeGenContext<'ctx, '_> {
pub fn get_attr_index(&mut self, ty: Type, attr: StrRef) -> (usize, Option<Constant>) { pub fn get_attr_index(&mut self, ty: Type, attr: StrRef) -> (usize, Option<Constant>) {
let obj_id = match &*self.unifier.get_ty(ty) { let obj_id = match &*self.unifier.get_ty(ty) {
TypeEnum::TObj { obj_id, .. } => *obj_id, TypeEnum::TObj { obj_id, .. } => *obj_id,
TypeEnum::TModule { module_id, .. } => *module_id,
// we cannot have other types, virtual type should be handled by function calls // we cannot have other types, virtual type should be handled by function calls
_ => codegen_unreachable!(self), _ => codegen_unreachable!(self),
}; };
@ -131,6 +137,8 @@ impl<'ctx> CodeGenContext<'ctx, '_> {
let attribute_index = attributes.iter().find_position(|x| x.0 == attr).unwrap(); let attribute_index = attributes.iter().find_position(|x| x.0 == attr).unwrap();
(attribute_index.0, Some(attribute_index.1 .2.clone())) (attribute_index.0, Some(attribute_index.1 .2.clone()))
} }
} else if let TopLevelDef::Module { attributes, .. } = &*def.read() {
(attributes.iter().find_position(|x| x.0 == attr).unwrap().0, None)
} else { } else {
codegen_unreachable!(self) codegen_unreachable!(self)
}; };
@ -2805,6 +2813,10 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
&*ctx.unifier.get_ty(value.custom.unwrap()) &*ctx.unifier.get_ty(value.custom.unwrap())
{ {
*obj_id *obj_id
} else if let TypeEnum::TModule { module_id, .. } =
&*ctx.unifier.get_ty(value.custom.unwrap())
{
*module_id
} else { } else {
codegen_unreachable!(ctx) codegen_unreachable!(ctx)
}; };
@ -2815,11 +2827,13 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
} else { } else {
let defs = ctx.top_level.definitions.read(); let defs = ctx.top_level.definitions.read();
let obj_def = defs.get(id.0).unwrap().read(); let obj_def = defs.get(id.0).unwrap().read();
let TopLevelDef::Class { methods, .. } = &*obj_def else { if let TopLevelDef::Class { methods, .. } = &*obj_def {
codegen_unreachable!(ctx)
};
methods.iter().find(|method| method.0 == *attr).unwrap().2 methods.iter().find(|method| method.0 == *attr).unwrap().2
} else if let TopLevelDef::Module { methods, .. } = &*obj_def {
*methods.iter().find(|method| method.0 == attr).unwrap().1
} else {
codegen_unreachable!(ctx)
}
}; };
// directly generate code for option.unwrap // directly generate code for option.unwrap
// since it needs to return static value to optimize for kernel invariant // since it needs to return static value to optimize for kernel invariant

View File

@ -501,6 +501,38 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
type_cache.get(&unifier.get_representative(ty)).copied().unwrap_or_else(|| { type_cache.get(&unifier.get_representative(ty)).copied().unwrap_or_else(|| {
let ty_enum = unifier.get_ty(ty); let ty_enum = unifier.get_ty(ty);
let result = match &*ty_enum { let result = match &*ty_enum {
TModule {module_id, attributes} => {
let top_level_defs = top_level.definitions.read();
let definition = top_level_defs.get(module_id.0).unwrap();
let TopLevelDef::Module { name, attributes: attribute_fields, .. } = &*definition.read() else {
unreachable!()
};
let ty: BasicTypeEnum<'_> = if let Some(t) = module.get_struct_type(&name.to_string()) {
t.ptr_type(AddressSpace::default()).into()
} else {
let struct_type = ctx.opaque_struct_type(&name.to_string());
type_cache.insert(
unifier.get_representative(ty),
struct_type.ptr_type(AddressSpace::default()).into(),
);
let module_fields: Vec<BasicTypeEnum<'_>> = attribute_fields.iter()
.map(|f| {
get_llvm_type(
ctx,
module,
generator,
unifier,
top_level,
type_cache,
attributes[&f.0].0,
)
})
.collect_vec();
struct_type.set_body(&module_fields, false);
struct_type.ptr_type(AddressSpace::default()).into()
};
return ty;
},
TObj { obj_id, fields, .. } => { TObj { obj_id, fields, .. } => {
// check to avoid treating non-class primitives as classes // check to avoid treating non-class primitives as classes
if PrimDef::contains_id(*obj_id) { if PrimDef::contains_id(*obj_id) {

View File

@ -598,10 +598,12 @@ impl dyn SymbolResolver + Send + Sync {
unifier.internal_stringify( unifier.internal_stringify(
ty, ty,
&mut |id| { &mut |id| {
let TopLevelDef::Class { name, .. } = &*top_level_defs[id].read() else { let top_level_def = &*top_level_defs[id].read();
unreachable!("expected class definition") let (TopLevelDef::Class { name, .. } | TopLevelDef::Module { name, .. }) =
top_level_def
else {
unreachable!("expected class/module definition")
}; };
name.to_string() name.to_string()
}, },
&mut |id| format!("typevar{id}"), &mut |id| format!("typevar{id}"),