diff --git a/nac3artiq/demo/module_support.py b/nac3artiq/demo/module_support.py new file mode 100644 index 00000000..a863b380 --- /dev/null +++ b/nac3artiq/demo/module_support.py @@ -0,0 +1,29 @@ +from min_artiq import * +import tests.string_attribute_issue337 as issue337 +import tests.support_class_attr_issue102 as issue102 +import tests.global_variables as global_variables + +@nac3 +class TestModuleSupport: + core: KernelInvariant[Core] + + def __init__(self): + self.core = Core() + + @kernel + def run(self): + # Accessing classes + issue337.Demo().run() + obj = issue102.Demo() + obj.attr3 = 3 + + # Calling functions + global_variables.inc_X() + global_variables.display_X() + + # Updating global variables + global_variables.X = 9 + global_variables.display_X() + +if __name__ == "__main__": + TestModuleSupport().run() \ No newline at end of file diff --git a/nac3artiq/demo/tests/global_variables.py b/nac3artiq/demo/tests/global_variables.py new file mode 100644 index 00000000..ac0e0cf0 --- /dev/null +++ b/nac3artiq/demo/tests/global_variables.py @@ -0,0 +1,14 @@ +from min_artiq import * +from numpy import int32 + +X: Kernel[int32] = 1 + +@rpc +def display_X(): + print_int32(X) + +@kernel +def inc_X(): + global X + X += 1 + diff --git a/nac3artiq/demo/string_attribute_issue337.py b/nac3artiq/demo/tests/string_attribute_issue337.py similarity index 57% rename from nac3artiq/demo/string_attribute_issue337.py rename to nac3artiq/demo/tests/string_attribute_issue337.py index 9749462a..c0b36ed6 100644 --- a/nac3artiq/demo/string_attribute_issue337.py +++ b/nac3artiq/demo/tests/string_attribute_issue337.py @@ -1,16 +1,13 @@ from min_artiq import * from numpy import int32 - @nac3 class Demo: - core: KernelInvariant[Core] - attr1: KernelInvariant[str] - attr2: KernelInvariant[int32] - + attr1: Kernel[str] + attr2: Kernel[int32] + @kernel def __init__(self): - self.core = Core() self.attr2 = 32 self.attr1 = "SAMPLE" @@ -19,6 +16,3 @@ class Demo: print_int32(self.attr2) self.attr1 - -if __name__ == "__main__": - Demo().run() diff --git a/nac3artiq/demo/support_class_attr_issue102.py b/nac3artiq/demo/tests/support_class_attr_issue102.py similarity index 99% rename from nac3artiq/demo/support_class_attr_issue102.py rename to nac3artiq/demo/tests/support_class_attr_issue102.py index 1b931444..0482e3f1 100644 --- a/nac3artiq/demo/support_class_attr_issue102.py +++ b/nac3artiq/demo/tests/support_class_attr_issue102.py @@ -1,7 +1,6 @@ from min_artiq import * from numpy import int32 - @nac3 class Demo: attr1: KernelInvariant[int32] = 2 @@ -12,7 +11,6 @@ class Demo: def __init__(self): self.attr3 = 8 - @nac3 class NAC3Devices: core: KernelInvariant[Core] @@ -35,6 +33,5 @@ class NAC3Devices: NAC3Devices.attr4 # Attributes accessible for classes without __init__ - if __name__ == "__main__": NAC3Devices().run() diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index c968198b..e4727056 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -1052,6 +1052,34 @@ pub fn attributes_writeback<'ctx>( )); } } + TypeEnum::TModule { attributes, .. } => { + let mut fields = Vec::new(); + let obj = inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap(); + + for (name, (field_ty, is_method)) in attributes { + if *is_method { + continue; + } + if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() { + fields.push(name.to_string()); + let (index, _) = ctx.get_attr_index(ty, *name); + values.push(( + *field_ty, + ctx.build_gep_and_load( + obj.into_pointer_value(), + &[zero, int32.const_int(index as u64, false)], + None, + ), + )); + } + } + if !fields.is_empty() { + let pydict = PyDict::new(py); + pydict.set_item("obj", val)?; + pydict.set_item("fields", fields)?; + host_attributes.append(pydict)?; + } + } _ => {} } } diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index d35e66d1..ba6c4fae 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -43,7 +43,7 @@ use nac3core::{ OptimizationLevel, }, nac3parser::{ - ast::{Constant, ExprKind, Located, Stmt, StmtKind, StrRef}, + ast::{self, Constant, ExprKind, Located, Stmt, StmtKind, StrRef}, parser::parse_program, }, symbol_resolver::SymbolResolver, @@ -159,6 +159,7 @@ pub struct PrimitivePythonId { generic_alias: (u64, u64), virtual_id: u64, option: u64, + module: u64, } type TopLevelComponent = (Stmt, String, PyObject); @@ -276,6 +277,10 @@ impl Nac3 { } }) } + // Allow global variable declaration with `Kernel` type annotation + StmtKind::AnnAssign { ref annotation, .. } => { + matches!(&annotation.node, ExprKind::Subscript { value, .. } if matches!(&value.node, ExprKind::Name {id, ..} if id == &"Kernel".into())) + } _ => false, }; @@ -469,12 +474,14 @@ impl Nac3 { ]; add_exceptions(&mut composer, &mut builtins_def, &mut builtins_ty, &exception_names); + // Stores a mapping from module id to attributes let mut module_to_resolver_cache: HashMap = HashMap::new(); let mut rpc_ids = vec![]; for (stmt, path, module) in &self.top_levels { let py_module: &PyAny = module.extract(py)?; let module_id: u64 = id_fn.call1((py_module,))?.extract()?; + let module_name: String = py_module.getattr("__name__")?.extract()?; let helper = helper.clone(); let class_obj; if let StmtKind::ClassDef { name, .. } = &stmt.node { @@ -489,7 +496,7 @@ impl Nac3 { } else { class_obj = None; } - let (name_to_pyid, resolver) = + let (name_to_pyid, resolver, _, _) = module_to_resolver_cache.get(&module_id).cloned().unwrap_or_else(|| { let mut name_to_pyid: HashMap = HashMap::new(); let members: &PyDict = @@ -518,9 +525,17 @@ impl Nac3 { }))) as Arc; let name_to_pyid = Rc::new(name_to_pyid); - module_to_resolver_cache - .insert(module_id, (name_to_pyid.clone(), resolver.clone())); - (name_to_pyid, resolver) + let module_location = ast::Location::new(1, 1, stmt.location.file); + module_to_resolver_cache.insert( + module_id, + ( + name_to_pyid.clone(), + resolver.clone(), + module_name.clone(), + Some(module_location), + ), + ); + (name_to_pyid, resolver, module_name, Some(module_location)) }); let (name, def_id, ty) = composer @@ -594,6 +609,24 @@ impl Nac3 { } } + // Adding top level module definitions + for (module_id, (module_name_to_pyid, module_resolver, module_name, module_location)) in + module_to_resolver_cache + { + let def_id = composer + .register_top_level_module( + &module_name, + &module_name_to_pyid, + module_resolver, + module_location, + ) + .map_err(|e| { + CompileError::new_err(format!("compilation failed\n----------\n{e}")) + })?; + + self.pyid_to_def.write().insert(module_id, def_id); + } + let id_fun = PyModule::import(py, "builtins")?.getattr("id")?; let mut name_to_pyid: HashMap = HashMap::new(); let module = PyModule::new(py, "tmp")?; @@ -713,6 +746,9 @@ impl Nac3 { "Unsupported @rpc annotation on global variable", ))) } + TopLevelDef::Module { .. } => { + unreachable!("Type module cannot be decorated with @rpc") + } } } } @@ -1096,6 +1132,7 @@ impl Nac3 { tuple: get_attr_id(builtins_mod, "tuple"), exception: get_attr_id(builtins_mod, "Exception"), option: get_id(artiq_builtins.get_item("Option").ok().flatten().unwrap()), + module: get_attr_id(types_mod, "ModuleType"), }; let working_directory = tempfile::Builder::new().prefix("nac3-").tempdir().unwrap(); diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index d9768669..4b398a9b 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -23,7 +23,7 @@ use nac3core::{ inkwell::{ module::Linkage, types::{BasicType, BasicTypeEnum}, - values::BasicValueEnum, + values::{BasicValue, BasicValueEnum}, AddressSpace, }, nac3parser::ast::{self, StrRef}, @@ -674,6 +674,48 @@ impl InnerResolver { }) }); + // check if obj is module + if self.helper.id_fn.call1(py, (ty.clone(),))?.extract::(py)? + == self.primitive_ids.module + && self.pyid_to_def.read().contains_key(&py_obj_id) + { + let def_id = self.pyid_to_def.read()[&py_obj_id]; + let def = defs[def_id.0].read(); + let TopLevelDef::Module { name: module_name, module_id, attributes, methods, .. } = + &*def + else { + unreachable!("must be a module here"); + }; + // Construct the module return type + let mut module_attributes = HashMap::new(); + for (name, _) in attributes { + let attribute_obj = obj.getattr(name.to_string().as_str())?; + let attribute_ty = + self.get_obj_type(py, attribute_obj, unifier, defs, primitives)?; + if let Ok(attribute_ty) = attribute_ty { + module_attributes.insert(*name, (attribute_ty, false)); + } else { + return Ok(Err(format!("Unable to resolve {module_name}.{name}"))); + } + } + + for name in methods.keys() { + let method_obj = obj.getattr(name.to_string().as_str())?; + let method_ty = self.get_obj_type(py, method_obj, unifier, defs, primitives)?; + if let Ok(method_ty) = method_ty { + module_attributes.insert(*name, (method_ty, true)); + } else { + return Ok(Err(format!("Unable to resolve {module_name}.{name}"))); + } + } + + let module_ty = + TypeEnum::TModule { module_id: *module_id, attributes: module_attributes }; + + let ty = unifier.add_ty(module_ty); + return Ok(Ok(ty)); + } + if let Some(ty) = constructor_ty { self.pyid_to_type.write().insert(py_obj_id, ty); return Ok(Ok(ty)); @@ -1373,6 +1415,77 @@ impl InnerResolver { None => Ok(None), } } + } else if ty_id == self.primitive_ids.module { + let id_str = id.to_string(); + + if let Some(global) = ctx.module.get_global(&id_str) { + return Ok(Some(global.as_pointer_value().into())); + } + + let top_level_defs = ctx.top_level.definitions.read(); + let ty = self + .get_obj_type(py, obj, &mut ctx.unifier, &top_level_defs, &ctx.primitives)? + .unwrap(); + let ty = ctx + .get_llvm_type(generator, ty) + .into_pointer_type() + .get_element_type() + .into_struct_type(); + + { + if self.global_value_ids.read().contains_key(&id) { + let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { + ctx.module.add_global(ty, Some(AddressSpace::default()), &id_str) + }); + return Ok(Some(global.as_pointer_value().into())); + } + self.global_value_ids.write().insert(id, obj.into()); + } + + let fields = { + let definition = + top_level_defs.get(self.pyid_to_def.read().get(&id).unwrap().0).unwrap().read(); + let TopLevelDef::Module { attributes, .. } = &*definition else { unreachable!() }; + attributes + .iter() + .filter_map(|f| { + let definition = top_level_defs.get(f.1 .0).unwrap().read(); + if let TopLevelDef::Variable { ty, .. } = &*definition { + Some((f.0, *ty)) + } else { + None + } + }) + .collect_vec() + }; + + let values: Result>, _> = fields + .iter() + .map(|(name, ty)| { + self.get_obj_value( + py, + obj.getattr(name.to_string().as_str())?, + ctx, + generator, + *ty, + ) + .map_err(|e| { + super::CompileError::new_err(format!("Error getting field {name}: {e}")) + }) + }) + .collect(); + let values = values?; + + if let Some(values) = values { + let val = ty.const_named_struct(&values); + let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { + ctx.module.add_global(ty, Some(AddressSpace::default()), &id_str) + }); + global.set_initializer(&val); + Ok(Some(global.as_pointer_value().into())) + } else { + Ok(None) + } } else { let id_str = id.to_string(); @@ -1555,9 +1668,50 @@ impl SymbolResolver for Resolver { fn get_symbol_value<'ctx>( &self, id: StrRef, - _: &mut CodeGenContext<'ctx, '_>, - _: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut dyn CodeGenerator, ) -> Option> { + if let Some(def_id) = self.0.id_to_def.read().get(&id) { + let top_levels = ctx.top_level.definitions.read(); + if matches!(&*top_levels[def_id.0].read(), TopLevelDef::Variable { .. }) { + let module_val = &self.0.module; + let ret = Python::with_gil(|py| -> PyResult> { + let module_val = module_val.as_ref(py); + + let ty = self.0.get_obj_type( + py, + module_val, + &mut ctx.unifier, + &top_levels, + &ctx.primitives, + )?; + if let Err(ty) = ty { + return Ok(Err(ty)); + } + let ty = ty.unwrap(); + let obj = self.0.get_obj_value(py, module_val, ctx, generator, ty)?.unwrap(); + let (idx, _) = ctx.get_attr_index(ty, id); + let ret = unsafe { + ctx.builder.build_gep( + obj.into_pointer_value(), + &[ + ctx.ctx.i32_type().const_zero(), + ctx.ctx.i32_type().const_int(idx as u64, false), + ], + id.to_string().as_str(), + ) + } + .unwrap(); + Ok(Ok(ret.as_basic_value_enum())) + }) + .unwrap(); + if ret.is_err() { + return None; + } + return Some(ret.unwrap().into()); + } + } + let sym_value = { let id_to_val = self.0.id_to_pyval.read(); id_to_val.get(&id).cloned() diff --git a/nac3core/src/codegen/concrete_type.rs b/nac3core/src/codegen/concrete_type.rs index d5c1fc38..503a4ae8 100644 --- a/nac3core/src/codegen/concrete_type.rs +++ b/nac3core/src/codegen/concrete_type.rs @@ -56,6 +56,10 @@ pub enum ConcreteTypeEnum { fields: HashMap, params: IndexMap, }, + TModule { + module_id: DefinitionId, + methods: HashMap, + }, TVirtual { ty: ConcreteType, }, @@ -205,6 +209,19 @@ impl ConcreteTypeStore { }) .collect(), }, + TypeEnum::TModule { module_id, attributes } => ConcreteTypeEnum::TModule { + module_id: *module_id, + methods: attributes + .iter() + .filter_map(|(name, ty)| match &*unifier.get_ty(ty.0) { + TypeEnum::TFunc(..) | TypeEnum::TObj { .. } => None, + _ => Some(( + *name, + (self.from_unifier_type(unifier, primitives, ty.0, cache), ty.1), + )), + }) + .collect(), + }, TypeEnum::TVirtual { ty } => ConcreteTypeEnum::TVirtual { ty: self.from_unifier_type(unifier, primitives, *ty, cache), }, @@ -284,6 +301,15 @@ impl ConcreteTypeStore { TypeVar { id, ty } })), }, + ConcreteTypeEnum::TModule { module_id, methods } => TypeEnum::TModule { + module_id: *module_id, + attributes: methods + .iter() + .map(|(name, cty)| { + (*name, (self.to_unifier_type(unifier, primitives, cty.0, cache), cty.1)) + }) + .collect::>(), + }, ConcreteTypeEnum::TFunc { args, ret, vars } => TypeEnum::TFunc(FunSignature { args: args .iter() diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 30b8dcd3..fd2cd286 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -61,8 +61,13 @@ pub fn get_subst_key( ) -> String { let mut vars = obj .map(|ty| { - let TypeEnum::TObj { params, .. } = &*unifier.get_ty(ty) else { unreachable!() }; - params.clone() + if let TypeEnum::TObj { params, .. } = &*unifier.get_ty(ty) { + params.clone() + } else if let TypeEnum::TModule { .. } = &*unifier.get_ty(ty) { + indexmap::IndexMap::new() + } else { + unreachable!() + } }) .unwrap_or_default(); vars.extend(fun_vars); @@ -120,6 +125,7 @@ impl<'ctx> CodeGenContext<'ctx, '_> { pub fn get_attr_index(&mut self, ty: Type, attr: StrRef) -> (usize, Option) { let obj_id = match &*self.unifier.get_ty(ty) { TypeEnum::TObj { obj_id, .. } => *obj_id, + TypeEnum::TModule { module_id, .. } => *module_id, // we cannot have other types, virtual type should be handled by function calls _ => codegen_unreachable!(self), }; @@ -131,6 +137,8 @@ impl<'ctx> CodeGenContext<'ctx, '_> { let attribute_index = attributes.iter().find_position(|x| x.0 == attr).unwrap(); (attribute_index.0, Some(attribute_index.1 .2.clone())) } + } else if let TopLevelDef::Module { attributes, .. } = &*def.read() { + (attributes.iter().find_position(|x| x.0 == attr).unwrap().0, None) } else { codegen_unreachable!(self) }; @@ -979,7 +987,7 @@ pub fn gen_call<'ctx, G: CodeGenerator>( TopLevelDef::Class { .. } => { return Ok(Some(generator.gen_constructor(ctx, fun.0, &def, params)?)) } - TopLevelDef::Variable { .. } => unreachable!(), + TopLevelDef::Variable { .. } | TopLevelDef::Module { .. } => unreachable!(), } } .or_else(|_: String| { @@ -2805,6 +2813,10 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( &*ctx.unifier.get_ty(value.custom.unwrap()) { *obj_id + } else if let TypeEnum::TModule { module_id, .. } = + &*ctx.unifier.get_ty(value.custom.unwrap()) + { + *module_id } else { codegen_unreachable!(ctx) }; @@ -2815,11 +2827,13 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( } else { let defs = ctx.top_level.definitions.read(); let obj_def = defs.get(id.0).unwrap().read(); - let TopLevelDef::Class { methods, .. } = &*obj_def else { + if let TopLevelDef::Class { methods, .. } = &*obj_def { + methods.iter().find(|method| method.0 == *attr).unwrap().2 + } else if let TopLevelDef::Module { methods, .. } = &*obj_def { + *methods.iter().find(|method| method.0 == attr).unwrap().1 + } else { codegen_unreachable!(ctx) - }; - - methods.iter().find(|method| method.0 == *attr).unwrap().2 + } }; // directly generate code for option.unwrap // since it needs to return static value to optimize for kernel invariant diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index dcfa2b8c..37e1bb33 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -501,6 +501,38 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( type_cache.get(&unifier.get_representative(ty)).copied().unwrap_or_else(|| { let ty_enum = unifier.get_ty(ty); let result = match &*ty_enum { + TModule {module_id, attributes} => { + let top_level_defs = top_level.definitions.read(); + let definition = top_level_defs.get(module_id.0).unwrap(); + let TopLevelDef::Module { name, attributes: attribute_fields, .. } = &*definition.read() else { + unreachable!() + }; + let ty: BasicTypeEnum<'_> = if let Some(t) = module.get_struct_type(&name.to_string()) { + t.ptr_type(AddressSpace::default()).into() + } else { + let struct_type = ctx.opaque_struct_type(&name.to_string()); + type_cache.insert( + unifier.get_representative(ty), + struct_type.ptr_type(AddressSpace::default()).into(), + ); + let module_fields: Vec> = attribute_fields.iter() + .map(|f| { + get_llvm_type( + ctx, + module, + generator, + unifier, + top_level, + type_cache, + attributes[&f.0].0, + ) + }) + .collect_vec(); + struct_type.set_body(&module_fields, false); + struct_type.ptr_type(AddressSpace::default()).into() + }; + return ty; + }, TObj { obj_id, fields, .. } => { // check to avoid treating non-class primitives as classes if PrimDef::contains_id(*obj_id) { diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index 2378dd62..48290935 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -598,10 +598,12 @@ impl dyn SymbolResolver + Send + Sync { unifier.internal_stringify( ty, &mut |id| { - let TopLevelDef::Class { name, .. } = &*top_level_defs[id].read() else { - unreachable!("expected class definition") + let top_level_def = &*top_level_defs[id].read(); + let (TopLevelDef::Class { name, .. } | TopLevelDef::Module { name, .. }) = + top_level_def + else { + unreachable!("expected class/module definition") }; - name.to_string() }, &mut |id| format!("typevar{id}"), diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index bd9a9214..a6a0ce76 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -101,7 +101,9 @@ impl TopLevelComposer { let builtin_name_list = definition_ast_list .iter() .map(|def_ast| match *def_ast.0.read() { - TopLevelDef::Class { name, .. } => name.to_string(), + TopLevelDef::Class { name, .. } | TopLevelDef::Module { name, .. } => { + name.to_string() + } TopLevelDef::Function { simple_name, .. } | TopLevelDef::Variable { simple_name, .. } => simple_name.to_string(), }) @@ -201,6 +203,43 @@ impl TopLevelComposer { self.definition_ast_list.iter().map(|(def, ..)| def.clone()).collect_vec() } + /// register top level modules + pub fn register_top_level_module( + &mut self, + module_name: &str, + name_to_pyid: &Rc>, + resolver: Arc, + location: Option, + ) -> Result { + let mut methods: HashMap = HashMap::new(); + let mut attributes: Vec<(StrRef, DefinitionId)> = Vec::new(); + + for (name, _) in name_to_pyid.iter() { + if let Ok(def_id) = resolver.get_identifier_def(*name) { + // Avoid repeated attribute instances resulting from multiple imports of same module + if self.defined_names.contains(&format!("{module_name}.{name}")) { + match &*self.definition_ast_list[def_id.0].0.read() { + TopLevelDef::Class { .. } | TopLevelDef::Function { .. } => { + methods.insert(*name, def_id); + } + _ => attributes.push((*name, def_id)), + } + } + }; + } + let module_def = TopLevelDef::Module { + name: module_name.to_string().into(), + module_id: DefinitionId(self.definition_ast_list.len()), + methods, + attributes, + resolver: Some(resolver), + loc: location, + }; + + self.definition_ast_list.push((Arc::new(RwLock::new(module_def)), None)); + Ok(DefinitionId(self.definition_ast_list.len() - 1)) + } + /// register, just remember the names of top level classes/function /// and check duplicate class/method/function definition pub fn register_top_level( @@ -469,10 +508,10 @@ impl TopLevelComposer { self.analyze_top_level_class_definition()?; self.analyze_top_level_class_fields_methods()?; self.analyze_top_level_function()?; + self.analyze_top_level_variables()?; if inference { self.analyze_function_instance()?; } - self.analyze_top_level_variables()?; Ok(()) } @@ -1410,7 +1449,7 @@ impl TopLevelComposer { Ok(()) } - /// step 4, analyze and call type inferencer to fill the `instance_to_stmt` of + /// step 5, analyze and call type inferencer to fill the `instance_to_stmt` of /// [`TopLevelDef::Function`] fn analyze_function_instance(&mut self) -> Result<(), HashSet> { // first get the class constructor type correct for the following type check in function body @@ -1941,7 +1980,7 @@ impl TopLevelComposer { Ok(()) } - /// Step 5. Analyze and populate the types of global variables. + /// Step 4. Analyze and populate the types of global variables. fn analyze_top_level_variables(&mut self) -> Result<(), HashSet> { let def_list = &self.definition_ast_list; let temp_def_list = self.extract_def_list(); @@ -1959,6 +1998,19 @@ impl TopLevelComposer { let resolver = &**resolver.as_ref().unwrap(); if let Some(ty_decl) = ty_decl { + let ty_decl = match &ty_decl.node { + ExprKind::Subscript { value, slice, .. } + if matches!( + &value.node, + ast::ExprKind::Name { id, .. } if self.core_config.kernel_ann.map_or(false, |c| id == &c.into()) + ) => + { + slice + } + _ if self.core_config.kernel_ann.is_none() => ty_decl, + _ => unreachable!("Global variables should be annotated with Kernel[]"), // ignore fields annotated otherwise + }; + let ty_annotation = parse_ast_to_type_annotation_kinds( resolver, &temp_def_list, diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 72d3eaa6..4ca5464f 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -379,21 +379,37 @@ pub fn make_exception_fields(int32: Type, int64: Type, str: Type) -> Vec<(StrRef impl TopLevelDef { pub fn to_string(&self, unifier: &mut Unifier) -> String { match self { - TopLevelDef::Class { name, ancestors, fields, methods, type_vars, .. } => { + TopLevelDef::Module { name, attributes, methods, .. } => { + format!( + "Module {{\nname: {:?},\nattributes: {:?}\nmethods: {:?}\n}}", + name, + attributes.iter().map(|(n, _)| n.to_string()).collect_vec(), + methods.iter().map(|(n, _)| n.to_string()).collect_vec() + ) + } + TopLevelDef::Class { + name, ancestors, fields, methods, attributes, type_vars, .. + } => { let fields_str = fields .iter() .map(|(n, ty, _)| (n.to_string(), unifier.stringify(*ty))) .collect_vec(); + let attributes_str = attributes + .iter() + .map(|(n, ty, _)| (n.to_string(), unifier.stringify(*ty))) + .collect_vec(); + let methods_str = methods .iter() .map(|(n, ty, id)| (n.to_string(), unifier.stringify(*ty), *id)) .collect_vec(); format!( - "Class {{\nname: {:?},\nancestors: {:?},\nfields: {:?},\nmethods: {:?},\ntype_vars: {:?}\n}}", + "Class {{\nname: {:?},\nancestors: {:?},\nfields: {:?},\nattributes: {:?},\nmethods: {:?},\ntype_vars: {:?}\n}}", name, ancestors.iter().map(|ancestor| ancestor.stringify(unifier)).collect_vec(), fields_str.iter().map(|(a, _)| a).collect_vec(), + attributes_str.iter().map(|(a, _)| a).collect_vec(), methods_str.iter().map(|(a, b, _)| (a, b)).collect_vec(), type_vars.iter().map(|id| unifier.stringify(*id)).collect_vec(), ) diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index cba2f5e7..3ffd568a 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -92,6 +92,20 @@ pub struct FunInstance { #[derive(Debug, Clone)] pub enum TopLevelDef { + Module { + /// Name of the module + name: StrRef, + /// Module ID used for [`TypeEnum`] + module_id: DefinitionId, + /// `DefinitionId` of `TopLevelDef::{Class, Function}` within the module + methods: HashMap, + /// `DefinitionId` of `TopLevelDef::{Variable}` within the module + attributes: Vec<(StrRef, DefinitionId)>, + /// Symbol resolver of the module defined the class. + resolver: Option>, + /// Definition location. + loc: Option, + }, Class { /// Name for error messages and symbols. name: StrRef, diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index 4332b474..8c827eed 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -3,10 +3,10 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", + "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(261)]\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index 60e0c194..b8a80a5c 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -3,13 +3,13 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[T]\"],\nfields: [\"a\", \"b\", \"c\"],\nmethods: [(\"__init__\", \"fn[[t:T], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"T\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[T]\"],\nfields: [\"a\", \"b\", \"c\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[t:T], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"T\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B[typevar245]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar245\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B[typevar245]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar245\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", - "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", + "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index 46601817..05f44884 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -4,10 +4,10 @@ expression: res_vec --- [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", - "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(258)]\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(263)]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", ] diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index da58d121..7d3922e7 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,10 +3,10 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar244, typevar245]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar244\", \"typevar245\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar244, typevar245]\"],\nfields: [\"a\", \"b\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar244\", \"typevar245\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:B], B]\",\nvar_id: []\n}\n", "Function {\nname: \"B.bar\",\nsig: \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index 8f384fa1..b55e9985 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -3,14 +3,14 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(264)]\n}\n", - "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", + "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nattributes: [],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(272)]\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_pass_in_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_pass_in_class.snap index 5178f1b4..2f37789c 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_pass_in_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_pass_in_class.snap @@ -1,9 +1,7 @@ --- source: nac3core/src/toplevel/test.rs -assertion_line: 549 expression: res_vec - --- [ - "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [],\nmethods: [],\ntype_vars: []\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [],\nattributes: [],\nmethods: [],\ntype_vars: []\n}\n", ] diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 742fa197..7ce659f3 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -2008,72 +2008,90 @@ impl Inferencer<'_> { ctx: ExprContext, ) -> InferenceResult { let ty = value.custom.unwrap(); - if let TypeEnum::TObj { obj_id, fields, .. } = &*self.unifier.get_ty(ty) { - // just a fast path - match (fields.get(&attr), ctx == ExprContext::Store) { - (Some((ty, true)), _) | (Some((ty, false)), false) => Ok(*ty), - (Some((ty, false)), true) => report_type_error( - TypeErrorKind::MutationError(RecordKey::Str(attr), *ty), - Some(value.location), - self.unifier, - ), - (None, mutable) => { - // Check whether it is a class attribute - let defs = self.top_level.definitions.read(); - let result = { - if let TopLevelDef::Class { attributes, .. } = &*defs[obj_id.0].read() { - attributes.iter().find_map(|f| { - if f.0 == attr { - return Some(f.1); - } - None - }) - } else { - None - } - }; - match result { - Some(res) if !mutable => Ok(res), - Some(_) => report_error( - &format!("Class Attribute `{attr}` is immutable"), - value.location, - ), - None => report_type_error( - TypeErrorKind::NoSuchField(RecordKey::Str(attr), ty), - Some(value.location), - self.unifier, - ), - } - } - } - } else if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(ty) { - // Access Class Attributes of classes with __init__ function using Class names e.g. Foo.ATTR1 - let result = { - self.top_level.definitions.read().iter().find_map(|def| { - if let Some(rear_guard) = def.try_read() { - if let TopLevelDef::Class { name, attributes, .. } = &*rear_guard { - if name.to_string() == self.unifier.stringify(sign.ret) { - return attributes.iter().find_map(|f| { + match &*self.unifier.get_ty(ty) { + TypeEnum::TObj { obj_id, fields, .. } => { + // just a fast path + match (fields.get(&attr), ctx == ExprContext::Store) { + (Some((ty, true)), _) | (Some((ty, false)), false) => Ok(*ty), + (Some((ty, false)), true) => report_type_error( + TypeErrorKind::MutationError(RecordKey::Str(attr), *ty), + Some(value.location), + self.unifier, + ), + (None, mutable) => { + // Check whether it is a class attribute + let defs = self.top_level.definitions.read(); + let result = { + if let TopLevelDef::Class { attributes, .. } = &*defs[obj_id.0].read() { + attributes.iter().find_map(|f| { if f.0 == attr { - return Some(f.clone().1); + return Some(f.1); } None - }); + }) + } else { + None } + }; + match result { + Some(res) if !mutable => Ok(res), + Some(_) => report_error( + &format!("Class Attribute `{attr}` is immutable"), + value.location, + ), + None => report_type_error( + TypeErrorKind::NoSuchField(RecordKey::Str(attr), ty), + Some(value.location), + self.unifier, + ), } } - None - }) - }; - match result { - Some(f) if ctx != ExprContext::Store => Ok(f), - Some(_) => { - report_error(&format!("Class Attribute `{attr}` is immutable"), value.location) } - None => self.infer_general_attribute(value, attr, ctx), } - } else { - self.infer_general_attribute(value, attr, ctx) + TypeEnum::TFunc(sign) => { + // Access Class Attributes of classes with __init__ function using Class names e.g. Foo.ATTR1 + let result = { + self.top_level.definitions.read().iter().find_map(|def| { + if let Some(rear_guard) = def.try_read() { + if let TopLevelDef::Class { name, attributes, .. } = &*rear_guard { + if name.to_string() == self.unifier.stringify(sign.ret) { + return attributes.iter().find_map(|f| { + if f.0 == attr { + return Some(f.clone().1); + } + None + }); + } + } + } + None + }) + }; + match result { + Some(f) if ctx != ExprContext::Store => Ok(f), + Some(_) => report_error( + &format!("Class Attribute `{attr}` is immutable"), + value.location, + ), + None => self.infer_general_attribute(value, attr, ctx), + } + } + TypeEnum::TModule { attributes, .. } => { + match (attributes.get(&attr), ctx == ExprContext::Load) { + (Some((ty, _)), true) | (Some((ty, false)), false) => Ok(*ty), + (Some((ty, true)), false) => report_type_error( + TypeErrorKind::MutationError(RecordKey::Str(attr), *ty), + Some(value.location), + self.unifier, + ), + (None, _) => report_type_error( + TypeErrorKind::NoSuchField(RecordKey::Str(attr), ty), + Some(value.location), + self.unifier, + ), + } + } + _ => self.infer_general_attribute(value, attr, ctx), } } @@ -2734,7 +2752,7 @@ impl Inferencer<'_> { .read() .iter() .map(|def| match *def.read() { - TopLevelDef::Class { name, .. } => (name, false), + TopLevelDef::Class { name, .. } | TopLevelDef::Module { name, .. } => (name, false), TopLevelDef::Function { simple_name, .. } => (simple_name, false), TopLevelDef::Variable { simple_name, .. } => (simple_name, true), }) diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index e190c4c4..f2f9ed6f 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -270,6 +270,19 @@ pub enum TypeEnum { /// A function type. TFunc(FunSignature), + + /// Module Type + TModule { + /// The [`DefinitionId`] of this object type. + module_id: DefinitionId, + + /// The attributes present in this object type. + /// + /// The key of the [Mapping] is the identifier of the field, while the value is a tuple + /// containing the [Type] of the field, and a `bool` indicating whether the field is a + /// variable (as opposed to a function). + attributes: Mapping, + }, } impl TypeEnum { @@ -284,6 +297,7 @@ impl TypeEnum { TypeEnum::TVirtual { .. } => "TVirtual", TypeEnum::TCall { .. } => "TCall", TypeEnum::TFunc { .. } => "TFunc", + TypeEnum::TModule { .. } => "TModule", } } } @@ -593,7 +607,8 @@ impl Unifier { | TLiteral { .. } // functions are instantiated for each call sites, so the function type can contain // type variables. - | TFunc { .. } => true, + | TFunc { .. } + | TModule { .. } => true, TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)), TCall { .. } => false, @@ -1315,10 +1330,12 @@ impl Unifier { || format!("{id}"), |top_level| { let top_level_def = &top_level.definitions.read()[id]; - let TopLevelDef::Class { name, .. } = &*top_level_def.read() else { - unreachable!("expected class definition") + let top_level_def = top_level_def.read(); + let (TopLevelDef::Class { name, .. } | TopLevelDef::Module { name, .. }) = + &*top_level_def + else { + unreachable!("expected module/class definition") }; - name.to_string() }, ) @@ -1446,6 +1463,10 @@ impl Unifier { let ret = self.internal_stringify(signature.ret, obj_to_name, var_to_name, notes); format!("fn[[{params}], {ret}]") } + TypeEnum::TModule { module_id, .. } => { + let name = obj_to_name(module_id.0); + name.to_string() + } } } @@ -1521,7 +1542,9 @@ impl Unifier { // variables, i.e. things like TRecord, TCall should not occur, and we // should be safe to not implement the substitution for those variants. match &*ty { - TypeEnum::TRigidVar { .. } | TypeEnum::TLiteral { .. } => None, + TypeEnum::TRigidVar { .. } | TypeEnum::TLiteral { .. } | TypeEnum::TModule { .. } => { + None + } TypeEnum::TVar { id, .. } => mapping.get(id).copied(), TypeEnum::TTuple { ty, is_vararg_ctx } => { let mut new_ty = Cow::from(ty);