nac3artiq: implement module type support #573

Merged
sb10q merged 9 commits from abdul124/nac3:issue40 into master 2025-01-17 12:40:41 +08:00
22 changed files with 563 additions and 115 deletions

View File

@ -0,0 +1,29 @@
from min_artiq import *
import tests.string_attribute_issue337 as issue337
import tests.support_class_attr_issue102 as issue102
import tests.global_variables as global_variables
@nac3
class TestModuleSupport:
core: KernelInvariant[Core]
def __init__(self):
self.core = Core()
@kernel
def run(self):
# Accessing classes
issue337.Demo().run()
obj = issue102.Demo()
obj.attr3 = 3
# Calling functions
global_variables.inc_X()
global_variables.display_X()
# Updating global variables
global_variables.X = 9
global_variables.display_X()
if __name__ == "__main__":
TestModuleSupport().run()

View File

@ -0,0 +1,14 @@
from min_artiq import *
from numpy import int32
X: Kernel[int32] = 1
@rpc
def display_X():
print_int32(X)
@kernel
def inc_X():
global X
X += 1

View File

@ -1,16 +1,13 @@
from min_artiq import *
from numpy import int32
@nac3
class Demo:
core: KernelInvariant[Core]
attr1: KernelInvariant[str]
attr2: KernelInvariant[int32]
attr1: Kernel[str]
attr2: Kernel[int32]
@kernel
def __init__(self):
self.core = Core()
self.attr2 = 32
self.attr1 = "SAMPLE"
@ -19,6 +16,3 @@ class Demo:
print_int32(self.attr2)
self.attr1
if __name__ == "__main__":
Demo().run()

View File

@ -1,7 +1,6 @@
from min_artiq import *
from numpy import int32
@nac3
class Demo:
attr1: KernelInvariant[int32] = 2
@ -12,7 +11,6 @@ class Demo:
def __init__(self):
self.attr3 = 8
@nac3
class NAC3Devices:
core: KernelInvariant[Core]
@ -35,6 +33,5 @@ class NAC3Devices:
NAC3Devices.attr4 # Attributes accessible for classes without __init__
if __name__ == "__main__":
NAC3Devices().run()

View File

@ -1052,6 +1052,34 @@ pub fn attributes_writeback<'ctx>(
));
}
}
TypeEnum::TModule { attributes, .. } => {
let mut fields = Vec::new();
let obj = inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap();
for (name, (field_ty, is_method)) in attributes {
if *is_method {
continue;
}
if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() {
fields.push(name.to_string());
let (index, _) = ctx.get_attr_index(ty, *name);
values.push((
*field_ty,
ctx.build_gep_and_load(
obj.into_pointer_value(),
&[zero, int32.const_int(index as u64, false)],
None,
),
));
}
}
if !fields.is_empty() {
let pydict = PyDict::new(py);
pydict.set_item("obj", val)?;
pydict.set_item("fields", fields)?;
host_attributes.append(pydict)?;
}
}
_ => {}
}
}

View File

@ -43,7 +43,7 @@ use nac3core::{
OptimizationLevel,
},
nac3parser::{
ast::{Constant, ExprKind, Located, Stmt, StmtKind, StrRef},
ast::{self, Constant, ExprKind, Located, Stmt, StmtKind, StrRef},
parser::parse_program,
},
symbol_resolver::SymbolResolver,
@ -159,6 +159,7 @@ pub struct PrimitivePythonId {
generic_alias: (u64, u64),
virtual_id: u64,
option: u64,
module: u64,
}
type TopLevelComponent = (Stmt, String, PyObject);
@ -276,6 +277,10 @@ impl Nac3 {
}
})
}
// Allow global variable declaration with `Kernel` type annotation
StmtKind::AnnAssign { ref annotation, .. } => {
matches!(&annotation.node, ExprKind::Subscript { value, .. } if matches!(&value.node, ExprKind::Name {id, ..} if id == &"Kernel".into()))
}
_ => false,
};
@ -469,12 +474,14 @@ impl Nac3 {
];
add_exceptions(&mut composer, &mut builtins_def, &mut builtins_ty, &exception_names);
// Stores a mapping from module id to attributes
let mut module_to_resolver_cache: HashMap<u64, _> = HashMap::new();
let mut rpc_ids = vec![];
for (stmt, path, module) in &self.top_levels {
let py_module: &PyAny = module.extract(py)?;
let module_id: u64 = id_fn.call1((py_module,))?.extract()?;
let module_name: String = py_module.getattr("__name__")?.extract()?;
let helper = helper.clone();
let class_obj;
if let StmtKind::ClassDef { name, .. } = &stmt.node {
@ -489,7 +496,7 @@ impl Nac3 {
} else {
class_obj = None;
}
let (name_to_pyid, resolver) =
let (name_to_pyid, resolver, _, _) =
module_to_resolver_cache.get(&module_id).cloned().unwrap_or_else(|| {
let mut name_to_pyid: HashMap<StrRef, u64> = HashMap::new();
let members: &PyDict =
@ -518,9 +525,17 @@ impl Nac3 {
})))
as Arc<dyn SymbolResolver + Send + Sync>;
let name_to_pyid = Rc::new(name_to_pyid);
module_to_resolver_cache
.insert(module_id, (name_to_pyid.clone(), resolver.clone()));
(name_to_pyid, resolver)
let module_location = ast::Location::new(1, 1, stmt.location.file);
module_to_resolver_cache.insert(
module_id,
(
name_to_pyid.clone(),
resolver.clone(),
module_name.clone(),
Some(module_location),
),
);
(name_to_pyid, resolver, module_name, Some(module_location))
});
let (name, def_id, ty) = composer
@ -594,6 +609,24 @@ impl Nac3 {
}
}
// Adding top level module definitions
for (module_id, (module_name_to_pyid, module_resolver, module_name, module_location)) in
module_to_resolver_cache
{
let def_id = composer
.register_top_level_module(
&module_name,
&module_name_to_pyid,
module_resolver,
module_location,
)
.map_err(|e| {
CompileError::new_err(format!("compilation failed\n----------\n{e}"))
})?;
self.pyid_to_def.write().insert(module_id, def_id);
}
let id_fun = PyModule::import(py, "builtins")?.getattr("id")?;
let mut name_to_pyid: HashMap<StrRef, u64> = HashMap::new();
let module = PyModule::new(py, "tmp")?;
@ -713,6 +746,9 @@ impl Nac3 {
"Unsupported @rpc annotation on global variable",
)))
}
TopLevelDef::Module { .. } => {
unreachable!("Type module cannot be decorated with @rpc")
}
}
}
}
@ -1096,6 +1132,7 @@ impl Nac3 {
tuple: get_attr_id(builtins_mod, "tuple"),
exception: get_attr_id(builtins_mod, "Exception"),
option: get_id(artiq_builtins.get_item("Option").ok().flatten().unwrap()),
module: get_attr_id(types_mod, "ModuleType"),
};
let working_directory = tempfile::Builder::new().prefix("nac3-").tempdir().unwrap();

View File

@ -23,7 +23,7 @@ use nac3core::{
inkwell::{
module::Linkage,
types::{BasicType, BasicTypeEnum},
values::BasicValueEnum,
values::{BasicValue, BasicValueEnum},
AddressSpace,
},
nac3parser::ast::{self, StrRef},
@ -674,6 +674,48 @@ impl InnerResolver {
})
});
// check if obj is module
if self.helper.id_fn.call1(py, (ty.clone(),))?.extract::<u64>(py)?
== self.primitive_ids.module
&& self.pyid_to_def.read().contains_key(&py_obj_id)
{
let def_id = self.pyid_to_def.read()[&py_obj_id];
let def = defs[def_id.0].read();
let TopLevelDef::Module { name: module_name, module_id, attributes, methods, .. } =
&*def
else {
unreachable!("must be a module here");
};
// Construct the module return type
let mut module_attributes = HashMap::new();
for (name, _) in attributes {
let attribute_obj = obj.getattr(name.to_string().as_str())?;
let attribute_ty =
self.get_obj_type(py, attribute_obj, unifier, defs, primitives)?;
if let Ok(attribute_ty) = attribute_ty {
module_attributes.insert(*name, (attribute_ty, false));
} else {
return Ok(Err(format!("Unable to resolve {module_name}.{name}")));
}
}
for name in methods.keys() {
let method_obj = obj.getattr(name.to_string().as_str())?;
let method_ty = self.get_obj_type(py, method_obj, unifier, defs, primitives)?;
if let Ok(method_ty) = method_ty {
module_attributes.insert(*name, (method_ty, true));
} else {
return Ok(Err(format!("Unable to resolve {module_name}.{name}")));
}
}
let module_ty =
TypeEnum::TModule { module_id: *module_id, attributes: module_attributes };
let ty = unifier.add_ty(module_ty);
return Ok(Ok(ty));
}
if let Some(ty) = constructor_ty {
self.pyid_to_type.write().insert(py_obj_id, ty);
return Ok(Ok(ty));
@ -1373,6 +1415,77 @@ impl InnerResolver {
None => Ok(None),
}
}
} else if ty_id == self.primitive_ids.module {
let id_str = id.to_string();
if let Some(global) = ctx.module.get_global(&id_str) {
return Ok(Some(global.as_pointer_value().into()));
}
let top_level_defs = ctx.top_level.definitions.read();
let ty = self
.get_obj_type(py, obj, &mut ctx.unifier, &top_level_defs, &ctx.primitives)?
.unwrap();
let ty = ctx
.get_llvm_type(generator, ty)
.into_pointer_type()
.get_element_type()
.into_struct_type();
{
if self.global_value_ids.read().contains_key(&id) {
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
ctx.module.add_global(ty, Some(AddressSpace::default()), &id_str)
});
return Ok(Some(global.as_pointer_value().into()));
}
self.global_value_ids.write().insert(id, obj.into());
}
let fields = {
let definition =
top_level_defs.get(self.pyid_to_def.read().get(&id).unwrap().0).unwrap().read();
let TopLevelDef::Module { attributes, .. } = &*definition else { unreachable!() };
attributes
.iter()
.filter_map(|f| {
let definition = top_level_defs.get(f.1 .0).unwrap().read();
if let TopLevelDef::Variable { ty, .. } = &*definition {
Some((f.0, *ty))
} else {
None
}
})
.collect_vec()
};
let values: Result<Option<Vec<_>>, _> = fields
.iter()
.map(|(name, ty)| {
self.get_obj_value(
py,
obj.getattr(name.to_string().as_str())?,
ctx,
generator,
*ty,
)
.map_err(|e| {
super::CompileError::new_err(format!("Error getting field {name}: {e}"))
})
})
.collect();
let values = values?;
if let Some(values) = values {
let val = ty.const_named_struct(&values);
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
ctx.module.add_global(ty, Some(AddressSpace::default()), &id_str)
});
global.set_initializer(&val);
Ok(Some(global.as_pointer_value().into()))
} else {
Ok(None)
}
} else {
let id_str = id.to_string();
@ -1555,9 +1668,50 @@ impl SymbolResolver for Resolver {
fn get_symbol_value<'ctx>(
&self,
id: StrRef,
_: &mut CodeGenContext<'ctx, '_>,
_: &mut dyn CodeGenerator,
ctx: &mut CodeGenContext<'ctx, '_>,
generator: &mut dyn CodeGenerator,
) -> Option<ValueEnum<'ctx>> {
if let Some(def_id) = self.0.id_to_def.read().get(&id) {
let top_levels = ctx.top_level.definitions.read();
if matches!(&*top_levels[def_id.0].read(), TopLevelDef::Variable { .. }) {
let module_val = &self.0.module;
let ret = Python::with_gil(|py| -> PyResult<Result<BasicValueEnum, String>> {
let module_val = module_val.as_ref(py);
let ty = self.0.get_obj_type(
py,
module_val,
&mut ctx.unifier,
&top_levels,
&ctx.primitives,
)?;
if let Err(ty) = ty {
return Ok(Err(ty));
}
let ty = ty.unwrap();
let obj = self.0.get_obj_value(py, module_val, ctx, generator, ty)?.unwrap();
let (idx, _) = ctx.get_attr_index(ty, id);
let ret = unsafe {
ctx.builder.build_gep(
obj.into_pointer_value(),
&[
ctx.ctx.i32_type().const_zero(),
ctx.ctx.i32_type().const_int(idx as u64, false),
],
id.to_string().as_str(),
)
}
.unwrap();
Ok(Ok(ret.as_basic_value_enum()))
})
.unwrap();
if ret.is_err() {
return None;
}
return Some(ret.unwrap().into());
}
}
let sym_value = {
let id_to_val = self.0.id_to_pyval.read();
id_to_val.get(&id).cloned()

View File

@ -56,6 +56,10 @@ pub enum ConcreteTypeEnum {
fields: HashMap<StrRef, (ConcreteType, bool)>,
params: IndexMap<TypeVarId, ConcreteType>,
},
TModule {
module_id: DefinitionId,
methods: HashMap<StrRef, (ConcreteType, bool)>,
},
TVirtual {
ty: ConcreteType,
},
@ -205,6 +209,19 @@ impl ConcreteTypeStore {
})
.collect(),
},
TypeEnum::TModule { module_id, attributes } => ConcreteTypeEnum::TModule {
module_id: *module_id,
methods: attributes
.iter()
.filter_map(|(name, ty)| match &*unifier.get_ty(ty.0) {
TypeEnum::TFunc(..) | TypeEnum::TObj { .. } => None,
_ => Some((
*name,
(self.from_unifier_type(unifier, primitives, ty.0, cache), ty.1),
)),
})
.collect(),
},
TypeEnum::TVirtual { ty } => ConcreteTypeEnum::TVirtual {
ty: self.from_unifier_type(unifier, primitives, *ty, cache),
},
@ -284,6 +301,15 @@ impl ConcreteTypeStore {
TypeVar { id, ty }
})),
},
ConcreteTypeEnum::TModule { module_id, methods } => TypeEnum::TModule {
module_id: *module_id,
attributes: methods
.iter()
.map(|(name, cty)| {
(*name, (self.to_unifier_type(unifier, primitives, cty.0, cache), cty.1))
})
.collect::<HashMap<_, _>>(),
},
ConcreteTypeEnum::TFunc { args, ret, vars } => TypeEnum::TFunc(FunSignature {
args: args
.iter()

View File

@ -61,8 +61,13 @@ pub fn get_subst_key(
) -> String {
let mut vars = obj
.map(|ty| {
let TypeEnum::TObj { params, .. } = &*unifier.get_ty(ty) else { unreachable!() };
params.clone()
if let TypeEnum::TObj { params, .. } = &*unifier.get_ty(ty) {
params.clone()
} else if let TypeEnum::TModule { .. } = &*unifier.get_ty(ty) {
indexmap::IndexMap::new()
} else {
unreachable!()
}
})
.unwrap_or_default();
vars.extend(fun_vars);
@ -120,6 +125,7 @@ impl<'ctx> CodeGenContext<'ctx, '_> {
pub fn get_attr_index(&mut self, ty: Type, attr: StrRef) -> (usize, Option<Constant>) {
let obj_id = match &*self.unifier.get_ty(ty) {
TypeEnum::TObj { obj_id, .. } => *obj_id,
TypeEnum::TModule { module_id, .. } => *module_id,
// we cannot have other types, virtual type should be handled by function calls
_ => codegen_unreachable!(self),
};
@ -131,6 +137,8 @@ impl<'ctx> CodeGenContext<'ctx, '_> {
let attribute_index = attributes.iter().find_position(|x| x.0 == attr).unwrap();
(attribute_index.0, Some(attribute_index.1 .2.clone()))
}
} else if let TopLevelDef::Module { attributes, .. } = &*def.read() {
(attributes.iter().find_position(|x| x.0 == attr).unwrap().0, None)
} else {
codegen_unreachable!(self)
};
@ -979,7 +987,7 @@ pub fn gen_call<'ctx, G: CodeGenerator>(
TopLevelDef::Class { .. } => {
return Ok(Some(generator.gen_constructor(ctx, fun.0, &def, params)?))
}
TopLevelDef::Variable { .. } => unreachable!(),
TopLevelDef::Variable { .. } | TopLevelDef::Module { .. } => unreachable!(),
}
}
.or_else(|_: String| {
@ -2805,6 +2813,10 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
&*ctx.unifier.get_ty(value.custom.unwrap())
{
*obj_id
} else if let TypeEnum::TModule { module_id, .. } =
&*ctx.unifier.get_ty(value.custom.unwrap())
{
*module_id
} else {
codegen_unreachable!(ctx)
};
@ -2815,11 +2827,13 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
} else {
let defs = ctx.top_level.definitions.read();
let obj_def = defs.get(id.0).unwrap().read();
let TopLevelDef::Class { methods, .. } = &*obj_def else {
if let TopLevelDef::Class { methods, .. } = &*obj_def {
methods.iter().find(|method| method.0 == *attr).unwrap().2
} else if let TopLevelDef::Module { methods, .. } = &*obj_def {
*methods.iter().find(|method| method.0 == attr).unwrap().1
} else {
codegen_unreachable!(ctx)
};
methods.iter().find(|method| method.0 == *attr).unwrap().2
}
};
// directly generate code for option.unwrap
// since it needs to return static value to optimize for kernel invariant

View File

@ -501,6 +501,38 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
type_cache.get(&unifier.get_representative(ty)).copied().unwrap_or_else(|| {
let ty_enum = unifier.get_ty(ty);
let result = match &*ty_enum {
TModule {module_id, attributes} => {
let top_level_defs = top_level.definitions.read();
let definition = top_level_defs.get(module_id.0).unwrap();
let TopLevelDef::Module { name, attributes: attribute_fields, .. } = &*definition.read() else {
unreachable!()
};
let ty: BasicTypeEnum<'_> = if let Some(t) = module.get_struct_type(&name.to_string()) {
t.ptr_type(AddressSpace::default()).into()
} else {
let struct_type = ctx.opaque_struct_type(&name.to_string());
type_cache.insert(
unifier.get_representative(ty),
struct_type.ptr_type(AddressSpace::default()).into(),
);
let module_fields: Vec<BasicTypeEnum<'_>> = attribute_fields.iter()
.map(|f| {
get_llvm_type(
ctx,
module,
generator,
unifier,
top_level,
type_cache,
attributes[&f.0].0,
)
})
.collect_vec();
struct_type.set_body(&module_fields, false);
struct_type.ptr_type(AddressSpace::default()).into()
};
return ty;
},
TObj { obj_id, fields, .. } => {
// check to avoid treating non-class primitives as classes
if PrimDef::contains_id(*obj_id) {

View File

@ -598,10 +598,12 @@ impl dyn SymbolResolver + Send + Sync {
unifier.internal_stringify(
ty,
&mut |id| {
let TopLevelDef::Class { name, .. } = &*top_level_defs[id].read() else {
unreachable!("expected class definition")
let top_level_def = &*top_level_defs[id].read();
let (TopLevelDef::Class { name, .. } | TopLevelDef::Module { name, .. }) =
top_level_def
else {
unreachable!("expected class/module definition")
};
name.to_string()
},
&mut |id| format!("typevar{id}"),

View File

@ -101,7 +101,9 @@ impl TopLevelComposer {
let builtin_name_list = definition_ast_list
.iter()
.map(|def_ast| match *def_ast.0.read() {
TopLevelDef::Class { name, .. } => name.to_string(),
TopLevelDef::Class { name, .. } | TopLevelDef::Module { name, .. } => {
name.to_string()
}
TopLevelDef::Function { simple_name, .. }
| TopLevelDef::Variable { simple_name, .. } => simple_name.to_string(),
})
@ -201,6 +203,43 @@ impl TopLevelComposer {
self.definition_ast_list.iter().map(|(def, ..)| def.clone()).collect_vec()
}
/// register top level modules
pub fn register_top_level_module(
&mut self,
module_name: &str,
name_to_pyid: &Rc<HashMap<StrRef, u64>>,
resolver: Arc<dyn SymbolResolver + Send + Sync>,
location: Option<Location>,
) -> Result<DefinitionId, String> {
let mut methods: HashMap<StrRef, DefinitionId> = HashMap::new();
let mut attributes: Vec<(StrRef, DefinitionId)> = Vec::new();
for (name, _) in name_to_pyid.iter() {
if let Ok(def_id) = resolver.get_identifier_def(*name) {
// Avoid repeated attribute instances resulting from multiple imports of same module
if self.defined_names.contains(&format!("{module_name}.{name}")) {
match &*self.definition_ast_list[def_id.0].0.read() {
TopLevelDef::Class { .. } | TopLevelDef::Function { .. } => {
methods.insert(*name, def_id);
}
_ => attributes.push((*name, def_id)),
}
}
};
}
let module_def = TopLevelDef::Module {
name: module_name.to_string().into(),
module_id: DefinitionId(self.definition_ast_list.len()),
methods,
attributes,
resolver: Some(resolver),
loc: location,
};
self.definition_ast_list.push((Arc::new(RwLock::new(module_def)), None));
Ok(DefinitionId(self.definition_ast_list.len() - 1))
}
/// register, just remember the names of top level classes/function
/// and check duplicate class/method/function definition
pub fn register_top_level(
@ -469,10 +508,10 @@ impl TopLevelComposer {
self.analyze_top_level_class_definition()?;
self.analyze_top_level_class_fields_methods()?;
self.analyze_top_level_function()?;
self.analyze_top_level_variables()?;
if inference {
self.analyze_function_instance()?;
}
self.analyze_top_level_variables()?;
Ok(())
}
@ -1410,7 +1449,7 @@ impl TopLevelComposer {
Ok(())
}
/// step 4, analyze and call type inferencer to fill the `instance_to_stmt` of
/// step 5, analyze and call type inferencer to fill the `instance_to_stmt` of
/// [`TopLevelDef::Function`]
fn analyze_function_instance(&mut self) -> Result<(), HashSet<String>> {
// first get the class constructor type correct for the following type check in function body
@ -1941,7 +1980,7 @@ impl TopLevelComposer {
Ok(())
}
/// Step 5. Analyze and populate the types of global variables.
/// Step 4. Analyze and populate the types of global variables.
fn analyze_top_level_variables(&mut self) -> Result<(), HashSet<String>> {
let def_list = &self.definition_ast_list;
let temp_def_list = self.extract_def_list();
@ -1959,6 +1998,19 @@ impl TopLevelComposer {
let resolver = &**resolver.as_ref().unwrap();
if let Some(ty_decl) = ty_decl {
let ty_decl = match &ty_decl.node {
ExprKind::Subscript { value, slice, .. }
if matches!(
&value.node,
ast::ExprKind::Name { id, .. } if self.core_config.kernel_ann.map_or(false, |c| id == &c.into())
) =>
{
slice
}
_ if self.core_config.kernel_ann.is_none() => ty_decl,
_ => unreachable!("Global variables should be annotated with Kernel[]"), // ignore fields annotated otherwise
};
let ty_annotation = parse_ast_to_type_annotation_kinds(
resolver,
&temp_def_list,

View File

@ -379,21 +379,37 @@ pub fn make_exception_fields(int32: Type, int64: Type, str: Type) -> Vec<(StrRef
impl TopLevelDef {
pub fn to_string(&self, unifier: &mut Unifier) -> String {
match self {
TopLevelDef::Class { name, ancestors, fields, methods, type_vars, .. } => {
TopLevelDef::Module { name, attributes, methods, .. } => {
format!(
"Module {{\nname: {:?},\nattributes: {:?}\nmethods: {:?}\n}}",
name,
attributes.iter().map(|(n, _)| n.to_string()).collect_vec(),
methods.iter().map(|(n, _)| n.to_string()).collect_vec()
)
}
TopLevelDef::Class {
name, ancestors, fields, methods, attributes, type_vars, ..
} => {
let fields_str = fields
.iter()
.map(|(n, ty, _)| (n.to_string(), unifier.stringify(*ty)))
.collect_vec();
let attributes_str = attributes
.iter()
.map(|(n, ty, _)| (n.to_string(), unifier.stringify(*ty)))
.collect_vec();
let methods_str = methods
.iter()
.map(|(n, ty, id)| (n.to_string(), unifier.stringify(*ty), *id))
.collect_vec();
format!(
"Class {{\nname: {:?},\nancestors: {:?},\nfields: {:?},\nmethods: {:?},\ntype_vars: {:?}\n}}",
"Class {{\nname: {:?},\nancestors: {:?},\nfields: {:?},\nattributes: {:?},\nmethods: {:?},\ntype_vars: {:?}\n}}",
name,
ancestors.iter().map(|ancestor| ancestor.stringify(unifier)).collect_vec(),
fields_str.iter().map(|(a, _)| a).collect_vec(),
attributes_str.iter().map(|(a, _)| a).collect_vec(),
methods_str.iter().map(|(a, b, _)| (a, b)).collect_vec(),
type_vars.iter().map(|id| unifier.stringify(*id)).collect_vec(),
)

View File

@ -92,6 +92,20 @@ pub struct FunInstance {
#[derive(Debug, Clone)]
pub enum TopLevelDef {
Module {
/// Name of the module
name: StrRef,
/// Module ID used for [`TypeEnum`]
module_id: DefinitionId,
/// `DefinitionId` of `TopLevelDef::{Class, Function}` within the module
methods: HashMap<StrRef, DefinitionId>,
/// `DefinitionId` of `TopLevelDef::{Variable}` within the module
attributes: Vec<(StrRef, DefinitionId)>,
/// Symbol resolver of the module defined the class.
resolver: Option<Arc<dyn SymbolResolver + Send + Sync>>,
/// Definition location.
loc: Option<Location>,
},
Class {
/// Name for error messages and symbols.
name: StrRef,

View File

@ -3,10 +3,10 @@ source: nac3core/src/toplevel/test.rs
expression: res_vec
---
[
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(261)]\n}\n",
]

View File

@ -3,13 +3,13 @@ source: nac3core/src/toplevel/test.rs
expression: res_vec
---
[
"Class {\nname: \"A\",\nancestors: [\"A[T]\"],\nfields: [\"a\", \"b\", \"c\"],\nmethods: [(\"__init__\", \"fn[[t:T], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"T\"]\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[T]\"],\nfields: [\"a\", \"b\", \"c\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[t:T], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"T\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar245]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar245\"]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar245]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar245\"]\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
]

View File

@ -4,10 +4,10 @@ expression: res_vec
---
[
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(258)]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(263)]\n}\n",
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
]

View File

@ -3,10 +3,10 @@ source: nac3core/src/toplevel/test.rs
expression: res_vec
---
[
"Class {\nname: \"A\",\nancestors: [\"A[typevar244, typevar245]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar244\", \"typevar245\"]\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[typevar244, typevar245]\"],\nfields: [\"a\", \"b\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar244\", \"typevar245\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:B], B]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.bar\",\nsig: \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\",\nvar_id: []\n}\n",

View File

@ -3,14 +3,14 @@ source: nac3core/src/toplevel/test.rs
expression: res_vec
---
[
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(264)]\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(272)]\n}\n",

View File

@ -1,9 +1,7 @@
---
source: nac3core/src/toplevel/test.rs
assertion_line: 549
expression: res_vec
---
[
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [],\nmethods: [],\ntype_vars: []\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [],\nattributes: [],\nmethods: [],\ntype_vars: []\n}\n",
]

View File

@ -2008,72 +2008,90 @@ impl Inferencer<'_> {
ctx: ExprContext,
) -> InferenceResult {
let ty = value.custom.unwrap();
if let TypeEnum::TObj { obj_id, fields, .. } = &*self.unifier.get_ty(ty) {
// just a fast path
match (fields.get(&attr), ctx == ExprContext::Store) {
(Some((ty, true)), _) | (Some((ty, false)), false) => Ok(*ty),
(Some((ty, false)), true) => report_type_error(
TypeErrorKind::MutationError(RecordKey::Str(attr), *ty),
Some(value.location),
self.unifier,
),
(None, mutable) => {
// Check whether it is a class attribute
let defs = self.top_level.definitions.read();
let result = {
if let TopLevelDef::Class { attributes, .. } = &*defs[obj_id.0].read() {
attributes.iter().find_map(|f| {
if f.0 == attr {
return Some(f.1);
}
None
})
} else {
None
}
};
match result {
Some(res) if !mutable => Ok(res),
Some(_) => report_error(
&format!("Class Attribute `{attr}` is immutable"),
value.location,
),
None => report_type_error(
TypeErrorKind::NoSuchField(RecordKey::Str(attr), ty),
Some(value.location),
self.unifier,
),
}
}
}
} else if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(ty) {
// Access Class Attributes of classes with __init__ function using Class names e.g. Foo.ATTR1
let result = {
self.top_level.definitions.read().iter().find_map(|def| {
if let Some(rear_guard) = def.try_read() {
if let TopLevelDef::Class { name, attributes, .. } = &*rear_guard {
if name.to_string() == self.unifier.stringify(sign.ret) {
return attributes.iter().find_map(|f| {
match &*self.unifier.get_ty(ty) {
TypeEnum::TObj { obj_id, fields, .. } => {
// just a fast path
match (fields.get(&attr), ctx == ExprContext::Store) {
(Some((ty, true)), _) | (Some((ty, false)), false) => Ok(*ty),
(Some((ty, false)), true) => report_type_error(
TypeErrorKind::MutationError(RecordKey::Str(attr), *ty),
Some(value.location),
self.unifier,
),
(None, mutable) => {
// Check whether it is a class attribute
let defs = self.top_level.definitions.read();
let result = {
if let TopLevelDef::Class { attributes, .. } = &*defs[obj_id.0].read() {
attributes.iter().find_map(|f| {
if f.0 == attr {
return Some(f.clone().1);
return Some(f.1);
}
None
});
})
} else {
None
}
};
match result {
Some(res) if !mutable => Ok(res),
Some(_) => report_error(
&format!("Class Attribute `{attr}` is immutable"),
value.location,
),
None => report_type_error(
TypeErrorKind::NoSuchField(RecordKey::Str(attr), ty),
Some(value.location),
self.unifier,
),
}
}
None
})
};
match result {
Some(f) if ctx != ExprContext::Store => Ok(f),
Some(_) => {
report_error(&format!("Class Attribute `{attr}` is immutable"), value.location)
}
None => self.infer_general_attribute(value, attr, ctx),
}
} else {
self.infer_general_attribute(value, attr, ctx)
TypeEnum::TFunc(sign) => {
// Access Class Attributes of classes with __init__ function using Class names e.g. Foo.ATTR1
let result = {
self.top_level.definitions.read().iter().find_map(|def| {
if let Some(rear_guard) = def.try_read() {
if let TopLevelDef::Class { name, attributes, .. } = &*rear_guard {
if name.to_string() == self.unifier.stringify(sign.ret) {
return attributes.iter().find_map(|f| {
if f.0 == attr {
return Some(f.clone().1);
}
None
});
}
}
}
None
})
};
match result {
Some(f) if ctx != ExprContext::Store => Ok(f),
Some(_) => report_error(
&format!("Class Attribute `{attr}` is immutable"),
value.location,
),
None => self.infer_general_attribute(value, attr, ctx),
}
}
TypeEnum::TModule { attributes, .. } => {
match (attributes.get(&attr), ctx == ExprContext::Load) {
(Some((ty, _)), true) | (Some((ty, false)), false) => Ok(*ty),
(Some((ty, true)), false) => report_type_error(
TypeErrorKind::MutationError(RecordKey::Str(attr), *ty),
Some(value.location),
self.unifier,
),
(None, _) => report_type_error(
TypeErrorKind::NoSuchField(RecordKey::Str(attr), ty),
Some(value.location),
self.unifier,
),
}
}
_ => self.infer_general_attribute(value, attr, ctx),
}
}
@ -2734,7 +2752,7 @@ impl Inferencer<'_> {
.read()
.iter()
.map(|def| match *def.read() {
TopLevelDef::Class { name, .. } => (name, false),
TopLevelDef::Class { name, .. } | TopLevelDef::Module { name, .. } => (name, false),
TopLevelDef::Function { simple_name, .. } => (simple_name, false),
TopLevelDef::Variable { simple_name, .. } => (simple_name, true),
})

View File

@ -270,6 +270,19 @@ pub enum TypeEnum {
/// A function type.
TFunc(FunSignature),
/// Module Type
TModule {
/// The [`DefinitionId`] of this object type.
module_id: DefinitionId,
/// The attributes present in this object type.
///
/// The key of the [Mapping] is the identifier of the field, while the value is a tuple
/// containing the [Type] of the field, and a `bool` indicating whether the field is a
/// variable (as opposed to a function).
attributes: Mapping<StrRef, (Type, bool)>,
},
}
impl TypeEnum {
@ -284,6 +297,7 @@ impl TypeEnum {
TypeEnum::TVirtual { .. } => "TVirtual",
TypeEnum::TCall { .. } => "TCall",
TypeEnum::TFunc { .. } => "TFunc",
TypeEnum::TModule { .. } => "TModule",
}
}
}
@ -593,7 +607,8 @@ impl Unifier {
| TLiteral { .. }
// functions are instantiated for each call sites, so the function type can contain
// type variables.
| TFunc { .. } => true,
| TFunc { .. }
| TModule { .. } => true,
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
TCall { .. } => false,
@ -1315,10 +1330,12 @@ impl Unifier {
|| format!("{id}"),
|top_level| {
let top_level_def = &top_level.definitions.read()[id];
let TopLevelDef::Class { name, .. } = &*top_level_def.read() else {
unreachable!("expected class definition")
let top_level_def = top_level_def.read();
let (TopLevelDef::Class { name, .. } | TopLevelDef::Module { name, .. }) =
&*top_level_def
else {
unreachable!("expected module/class definition")
};
name.to_string()
},
)
@ -1446,6 +1463,10 @@ impl Unifier {
let ret = self.internal_stringify(signature.ret, obj_to_name, var_to_name, notes);
format!("fn[[{params}], {ret}]")
}
TypeEnum::TModule { module_id, .. } => {
let name = obj_to_name(module_id.0);
name.to_string()
}
}
}
@ -1521,7 +1542,9 @@ impl Unifier {
// variables, i.e. things like TRecord, TCall should not occur, and we
// should be safe to not implement the substitution for those variants.
match &*ty {
TypeEnum::TRigidVar { .. } | TypeEnum::TLiteral { .. } => None,
TypeEnum::TRigidVar { .. } | TypeEnum::TLiteral { .. } | TypeEnum::TModule { .. } => {
None
}
TypeEnum::TVar { id, .. } => mapping.get(id).copied(),
TypeEnum::TTuple { ty, is_vararg_ctx } => {
let mut new_ty = Cow::from(ty);