diff --git a/nac3artiq/demo/support_class_attr_issue102.py b/nac3artiq/demo/support_class_attr_issue102.py new file mode 100644 index 00000000..1b931444 --- /dev/null +++ b/nac3artiq/demo/support_class_attr_issue102.py @@ -0,0 +1,40 @@ +from min_artiq import * +from numpy import int32 + + +@nac3 +class Demo: + attr1: KernelInvariant[int32] = 2 + attr2: int32 = 4 + attr3: Kernel[int32] + + @kernel + def __init__(self): + self.attr3 = 8 + + +@nac3 +class NAC3Devices: + core: KernelInvariant[Core] + attr4: KernelInvariant[int32] = 16 + + def __init__(self): + self.core = Core() + + @kernel + def run(self): + Demo.attr1 # Supported + # Demo.attr2 # Field not accessible on Kernel + # Demo.attr3 # Only attributes can be accessed in this way + # Demo.attr1 = 2 # Attributes are immutable + + self.attr4 # Attributes can be accessed within class + + obj = Demo() + obj.attr1 # Attributes can be accessed by class objects + + NAC3Devices.attr4 # Attributes accessible for classes without __init__ + + +if __name__ == "__main__": + NAC3Devices().run() diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 7e64dc1d..d3b7000e 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -657,7 +657,7 @@ pub fn attributes_writeback( } if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() { attributes.push(name.to_string()); - let index = ctx.get_attr_index(ty, *name); + let (index, _) = ctx.get_attr_index(ty, *name); values.push(( *field_ty, ctx.build_gep_and_load( diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index b9fb6ba8..62597b8c 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -627,12 +627,15 @@ impl InnerResolver { let pyid_to_def = self.pyid_to_def.read(); let constructor_ty = pyid_to_def.get(&py_obj_id).and_then(|def_id| { defs.iter().find_map(|def| { - if let TopLevelDef::Class { object_id, methods, constructor, .. } = &*def.read() { - if object_id == def_id - && constructor.is_some() - && methods.iter().any(|(s, _, _)| s == &"__init__".into()) + if let Some(rear_guard) = def.try_read() { + if let TopLevelDef::Class { object_id, methods, constructor, .. } = &*rear_guard { - return *constructor; + if object_id == def_id + && constructor.is_some() + && methods.iter().any(|(s, _, _)| s == &"__init__".into()) + { + return *constructor; + } } } None @@ -664,7 +667,29 @@ impl InnerResolver { primitives, )? { Ok(s) => s, - Err(e) => return Ok(Err(e)), + Err(e) => { + // Allow access to Class Attributes of Classes without having to initialize Objects + if self.pyid_to_def.read().contains_key(&py_obj_id) { + if let Some(def_id) = self.pyid_to_def.read().get(&py_obj_id).copied() { + let def = defs[def_id.0].read(); + let TopLevelDef::Class { object_id, .. } = &*def else { + // only object is supported, functions are not supported + unreachable!("function type is not supported, should not be queried") + }; + + let ty = TypeEnum::TObj { + obj_id: *object_id, + params: VarMap::new(), + fields: HashMap::new(), + }; + (unifier.add_ty(ty), true) + } else { + return Ok(Err(e)); + } + } else { + return Ok(Err(e)); + } + } }; match (&*unifier.get_ty(extracted_ty), inst_check) { // do the instantiation for these four types diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 3c5d0a3d..37e6b863 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -86,19 +86,35 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { get_subst_key(&mut self.unifier, obj, &fun.vars, filter) } - pub fn get_attr_index(&mut self, ty: Type, attr: StrRef) -> usize { + /// Checks the field and attributes of classes + /// Returns the index of attr in class fields otherwise returns the attribute value + pub fn get_attr_index(&mut self, ty: Type, attr: StrRef) -> (usize, Option) { let obj_id = match &*self.unifier.get_ty(ty) { TypeEnum::TObj { obj_id, .. } => *obj_id, // we cannot have other types, virtual type should be handled by function calls _ => unreachable!(), }; let def = &self.top_level.definitions.read()[obj_id.0]; - let index = if let TopLevelDef::Class { fields, .. } = &*def.read() { - fields.iter().find_position(|x| x.0 == attr).unwrap().0 + let (index, value) = if let TopLevelDef::Class { fields, attributes, .. } = &*def.read() { + if let Some(field_index) = fields.iter().find_position(|x| x.0 == attr) { + (field_index.0, None) + } else { + let attribute_index = attributes.iter().find_position(|x| x.0 == attr).unwrap(); + (attribute_index.0, Some(attribute_index.1 .2.clone())) + } } else { unreachable!() }; - index + (index, value) + } + + pub fn get_attr_index_object(&mut self, ty: Type, attr: StrRef) -> usize { + match &*self.unifier.get_ty(ty) { + TypeEnum::TObj { fields, .. } => { + fields.iter().find_position(|x| *x.0 == attr).unwrap().0 + } + _ => unreachable!(), + } } pub fn gen_symbol_val( @@ -2166,11 +2182,72 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( } ExprKind::Attribute { value, attr, .. } => { // note that we would handle class methods directly in calls + + // Change Class attribute access requests to accessing constants from Class Definition + if let Some(c) = value.custom { + if let TypeEnum::TFunc(_) = &*ctx.unifier.get_ty(c) { + let defs = ctx.top_level.definitions.read(); + let result = defs.iter().find_map(|def| { + if let Some(rear_guard) = def.try_read() { + if let TopLevelDef::Class { + constructor: Some(constructor), + attributes, + .. + } = &*rear_guard + { + if *constructor == c { + return attributes.iter().find_map(|f| { + if f.0 == *attr { + // All other checks performed by this point + return Some(f.2.clone()); + } + None + }); + } + } + } + None + }); + match result { + Some(val) => { + let mut modified_expr = expr.clone(); + modified_expr.node = ExprKind::Constant { value: val, kind: None }; + + return generator.gen_expr(ctx, &modified_expr); + } + None => unreachable!("Function Type should not have attributes"), + } + } else if let TypeEnum::TObj { obj_id, fields, params } = &*ctx.unifier.get_ty(c) { + if fields.is_empty() && params.is_empty() { + let defs = ctx.top_level.definitions.read(); + let def = defs[obj_id.0].read(); + match if let TopLevelDef::Class { attributes, .. } = &*def { + attributes.iter().find_map(|f| { + if f.0 == *attr { + return Some(f.2.clone()); + } + None + }) + } else { + None + } { + Some(val) => { + let mut modified_expr = expr.clone(); + modified_expr.node = ExprKind::Constant { value: val, kind: None }; + + return generator.gen_expr(ctx, &modified_expr); + } + None => unreachable!(), + } + } + } + } + match generator.gen_expr(ctx, value)? { Some(ValueEnum::Static(v)) => v.get_field(*attr, ctx).map_or_else( || { let v = v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?; - let index = ctx.get_attr_index(value.custom.unwrap(), *attr); + let (index, _) = ctx.get_attr_index(value.custom.unwrap(), *attr); Ok(ValueEnum::Dynamic(ctx.build_gep_and_load( v.into_pointer_value(), &[zero, int32.const_int(index as u64, false)], @@ -2180,7 +2257,14 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( Ok, )?, Some(ValueEnum::Dynamic(v)) => { - let index = ctx.get_attr_index(value.custom.unwrap(), *attr); + let (index, attr_value) = ctx.get_attr_index(value.custom.unwrap(), *attr); + if let Some(val) = attr_value { + // Change to Constant Construct + let mut modified_expr = expr.clone(); + modified_expr.node = ExprKind::Constant { value: val, kind: None }; + + return generator.gen_expr(ctx, &modified_expr); + } ValueEnum::Dynamic(ctx.build_gep_and_load( v.into_pointer_value(), &[zero, int32.const_int(index as u64, false)], @@ -2363,6 +2447,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( ExprKind::Attribute { value, attr, .. } => { let Some(val) = generator.gen_expr(ctx, value)? else { return Ok(None) }; + // Handle Class Method calls let id = if let TypeEnum::TObj { obj_id, .. } = &*ctx.unifier.get_ty(value.custom.unwrap()) { diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index a670f117..5bae9a94 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -113,7 +113,7 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>( } }, ExprKind::Attribute { value, attr, .. } => { - let index = ctx.get_attr_index(value.custom.unwrap(), *attr); + let (index, _) = ctx.get_attr_index(value.custom.unwrap(), *attr); let val = if let Some(v) = generator.gen_expr(ctx, value)? { v.to_basic_value_enum(ctx, generator, value.custom.unwrap())? } else { diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index ef6618bc..5531bddc 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -83,6 +83,7 @@ pub fn get_exn_constructor( object_id: DefinitionId(class_id), type_vars: Vec::default(), fields: exception_fields, + attributes: Vec::default(), methods: vec![("__init__".into(), signature, DefinitionId(cons_id))], ancestors: vec![ TypeAnnotation::CustomClass { id: DefinitionId(class_id), params: Vec::default() }, @@ -596,6 +597,7 @@ impl<'a> BuiltinBuilder<'a> { object_id: prim.id(), type_vars: Vec::default(), fields: make_exception_fields(int32, int64, str), + attributes: Vec::default(), methods: Vec::default(), ancestors: vec![], constructor: None, @@ -624,7 +626,8 @@ impl<'a> BuiltinBuilder<'a> { name: prim.name().into(), object_id: prim.id(), type_vars: vec![self.option_tvar.ty], - fields: vec![], + fields: Vec::default(), + attributes: Vec::default(), methods: vec![ Self::create_method(PrimDef::OptionIsSome, self.is_some_ty.0), Self::create_method(PrimDef::OptionIsNone, self.is_some_ty.0), @@ -738,6 +741,7 @@ impl<'a> BuiltinBuilder<'a> { object_id: prim.id(), type_vars: vec![self.ndarray_dtype_tvar.ty, self.ndarray_ndims_tvar.ty], fields: Vec::default(), + attributes: Vec::default(), methods: vec![ Self::create_method(PrimDef::NDArrayCopy, self.ndarray_copy_ty.0), Self::create_method(PrimDef::NDArrayFill, self.ndarray_fill_ty.0), diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index ff0e82c9..5ba07df5 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -1057,7 +1057,14 @@ impl TopLevelComposer { let (keyword_list, core_config) = core_info; let mut class_def = class_def.write(); let TopLevelDef::Class { - object_id, ancestors, fields, methods, resolver, type_vars, .. + object_id, + ancestors, + fields, + attributes, + methods, + resolver, + type_vars, + .. } = &mut *class_def else { unreachable!("here must be toplevel class def"); @@ -1073,10 +1080,14 @@ impl TopLevelComposer { class_body_ast, _class_ancestor_def, class_fields_def, + class_attributes_def, class_methods_def, class_type_vars_def, class_resolver, - ) = (*object_id, *name, bases, body, ancestors, fields, methods, type_vars, resolver); + ) = ( + *object_id, *name, bases, body, ancestors, fields, attributes, methods, type_vars, + resolver, + ); let class_resolver = class_resolver.as_ref().unwrap(); let class_resolver = class_resolver.as_ref(); @@ -1285,34 +1296,74 @@ impl TopLevelComposer { .unify(method_dummy_ty, method_type) .map_err(|e| HashSet::from([e.to_display(unifier).to_string()]))?; } - ast::StmtKind::AnnAssign { target, annotation, value: None, .. } => { + ast::StmtKind::AnnAssign { target, annotation, value, .. } => { if let ast::ExprKind::Name { id: attr, .. } = &target.node { if defined_fields.insert(attr.to_string()) { let dummy_field_type = unifier.get_dummy_var().ty; - // handle Kernel[T], KernelInvariant[T] - let (annotation, mutable) = match &annotation.node { - ast::ExprKind::Subscript { value, slice, .. } - if matches!( - &value.node, - ast::ExprKind::Name { id, .. } if id == &core_config.kernel_invariant_ann.into() - ) => - { - (slice, false) + let annotation = match value { + None => { + // handle Kernel[T], KernelInvariant[T] + let (annotation, mutable) = match &annotation.node { + ast::ExprKind::Subscript { value, slice, .. } + if matches!( + &value.node, + ast::ExprKind::Name { id, .. } if id == &core_config.kernel_invariant_ann.into() + ) => + { + (slice, false) + } + ast::ExprKind::Subscript { value, slice, .. } + if matches!( + &value.node, + ast::ExprKind::Name { id, .. } if core_config.kernel_ann.map_or(false, |c| id == &c.into()) + ) => + { + (slice, true) + } + _ if core_config.kernel_ann.is_none() => (annotation, true), + _ => continue, // ignore fields annotated otherwise + }; + class_fields_def.push((*attr, dummy_field_type, mutable)); + annotation } - ast::ExprKind::Subscript { value, slice, .. } - if matches!( - &value.node, - ast::ExprKind::Name { id, .. } if core_config.kernel_ann.map_or(false, |c| id == &c.into()) - ) => - { - (slice, true) - } - _ if core_config.kernel_ann.is_none() => (annotation, true), - _ => continue, // ignore fields annotated otherwise - }; - class_fields_def.push((*attr, dummy_field_type, mutable)); + // Supporting Class Attributes + Some(boxed_expr) => { + // Class attributes are set as immutable regardless + let (annotation, _) = match &annotation.node { + ast::ExprKind::Subscript { slice, .. } => (slice, false), + _ if core_config.kernel_ann.is_none() => (annotation, false), + _ => continue, + }; + match &**boxed_expr { + ast::Located {location: _, custom: (), node: ast::ExprKind::Constant { value: v, kind: _ }} => { + // Restricting the types allowed to be defined as class attributes + match v { + ast::Constant::Bool(_) | ast::Constant::Str(_) | ast::Constant::Int(_) | ast::Constant::Float(_) => {} + _ => { + return Err(HashSet::from([ + format!( + "unsupported statement in class definition body (at {})", + b.location + ), + ])) + } + } + class_attributes_def.push((*attr, dummy_field_type, v.clone())); + } + _ => { + return Err(HashSet::from([ + format!( + "unsupported statement in class definition body (at {})", + b.location + ), + ])) + } + } + annotation + } + }; let parsed_annotation = parse_ast_to_type_annotation_kinds( class_resolver, temp_def_list, @@ -1384,7 +1435,14 @@ impl TopLevelComposer { type_var_to_concrete_def: &mut HashMap, ) -> Result<(), HashSet> { let TopLevelDef::Class { - object_id, ancestors, fields, methods, resolver, type_vars, .. + object_id, + ancestors, + fields, + attributes, + methods, + resolver, + type_vars, + .. } = class_def else { unreachable!("here must be class def ast") @@ -1393,10 +1451,11 @@ impl TopLevelComposer { _class_id, class_ancestor_def, class_fields_def, + class_attribute_def, class_methods_def, _class_type_vars_def, _class_resolver, - ) = (*object_id, ancestors, fields, methods, type_vars, resolver); + ) = (*object_id, ancestors, fields, attributes, methods, type_vars, resolver); // since when this function is called, the ancestors of the direct parent // are supposed to be already handled, so we only need to deal with the direct parent @@ -1407,7 +1466,7 @@ impl TopLevelComposer { let base = temp_def_list.get(id.0).unwrap(); let base = base.read(); - let TopLevelDef::Class { methods, fields, .. } = &*base else { + let TopLevelDef::Class { methods, fields, attributes, .. } = &*base else { unreachable!("must be top level class def") }; @@ -1449,7 +1508,7 @@ impl TopLevelComposer { } } // use the new_child_methods to replace all the elements in `class_methods_def` - class_methods_def.drain(..); + class_methods_def.clear(); class_methods_def.extend(new_child_methods); // handle class fields @@ -1459,7 +1518,9 @@ impl TopLevelComposer { let to_be_added = (*anc_field_name, *anc_field_ty, *mutable); // find if there is a fields with the same name in the child class for (class_field_name, ..) in &*class_fields_def { - if class_field_name == anc_field_name { + if class_field_name == anc_field_name + || attributes.iter().any(|f| f.0 == *class_field_name) + { return Err(HashSet::from([format!( "field `{class_field_name}` has already declared in the ancestor classes" )])); @@ -1467,14 +1528,33 @@ impl TopLevelComposer { } new_child_fields.push(to_be_added); } + + // handle class attributes + let mut new_child_attributes: Vec<(StrRef, Type, ast::Constant)> = Vec::new(); + for (anc_attr_name, anc_attr_ty, attr_value) in attributes { + let to_be_added = (*anc_attr_name, *anc_attr_ty, attr_value.clone()); + // find if there is a attribute with the same name in the child class + for (class_attr_name, ..) in &*class_attribute_def { + if class_attr_name == anc_attr_name + || fields.iter().any(|f| f.0 == *class_attr_name) + { + return Err(HashSet::from([format!( + "attribute `{class_attr_name}` has already declared in the ancestor classes" + )])); + } + } + new_child_attributes.push(to_be_added); + } + for (class_field_name, class_field_ty, mutable) in &*class_fields_def { if !is_override.contains(class_field_name) { new_child_fields.push((*class_field_name, *class_field_ty, *mutable)); } } - class_fields_def.drain(..); + class_fields_def.clear(); class_fields_def.extend(new_child_fields); - + class_attribute_def.clear(); + class_attribute_def.extend(new_child_attributes); Ok(()) } diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 2fcc24c4..73812505 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -474,6 +474,7 @@ impl TopLevelComposer { object_id: obj_id, type_vars: Vec::default(), fields: Vec::default(), + attributes: Vec::default(), methods: Vec::default(), ancestors: Vec::default(), constructor, diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index c08f287f..344c75bb 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -103,6 +103,10 @@ pub enum TopLevelDef { /// /// Name and type is mutable. fields: Vec<(StrRef, Type, bool)>, + /// Class Attributes. + /// + /// Name, type, value. + attributes: Vec<(StrRef, Type, ast::Constant)>, /// Class methods, pointing to the corresponding function definition. methods: Vec<(StrRef, Type, DefinitionId)>, /// Ancestor classes, including itself. diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index 7b5c4056..f598badf 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -470,6 +470,7 @@ pub fn get_type_from_type_annotation_kinds( } result }; + // Class Attributes keep a copy with Class Definition and are not added to objects let mut tobj_fields = methods .iter() .map(|(name, ty, _)| { diff --git a/nac3core/src/typecheck/type_error.rs b/nac3core/src/typecheck/type_error.rs index abf61a56..e6d7d73a 100644 --- a/nac3core/src/typecheck/type_error.rs +++ b/nac3core/src/typecheck/type_error.rs @@ -34,6 +34,7 @@ pub enum TypeErrorKind { }, RequiresTypeAnn, PolymorphicFunctionPointer, + NoSuchAttribute(RecordKey, Type), } #[derive(Debug, Clone)] @@ -156,6 +157,10 @@ impl<'a> Display for DisplayTypeError<'a> { let t = self.unifier.stringify_with_notes(*t, &mut notes); write!(f, "`{t}::{name}` field/method does not exist") } + NoSuchAttribute(name, t) => { + let t = self.unifier.stringify_with_notes(*t, &mut notes); + write!(f, "`{t}::{name}` is not a class attribute") + } TupleIndexOutOfBounds { index, len } => { write!( f, diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 1925174a..a3366c9b 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -6,6 +6,7 @@ use std::{cell::RefCell, sync::Arc}; use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier, VarMap}; use super::{magic_methods::*, type_error::TypeError, typedef::CallId}; +use crate::toplevel::TopLevelDef; use crate::{ symbol_resolver::{SymbolResolver, SymbolValue}, toplevel::{ @@ -1441,6 +1442,24 @@ impl<'a> Inferencer<'a> { Ok(self.unifier.add_ty(TypeEnum::TTuple { ty })) } + /// Checks for non-class attributes + fn infer_general_attribute( + &mut self, + value: &ast::Expr>, + attr: StrRef, + ctx: ExprContext, + ) -> InferenceResult { + let attr_ty = self.unifier.get_dummy_var().ty; + let fields = once(( + attr.into(), + RecordField::new(attr_ty, ctx == ExprContext::Store, Some(value.location)), + )) + .collect(); + let record = self.unifier.add_record(fields); + self.constrain(value.custom.unwrap(), record, &value.location)?; + Ok(attr_ty) + } + fn infer_attribute( &mut self, value: &ast::Expr>, @@ -1448,31 +1467,72 @@ impl<'a> Inferencer<'a> { ctx: ExprContext, ) -> InferenceResult { let ty = value.custom.unwrap(); - if let TypeEnum::TObj { fields, .. } = &*self.unifier.get_ty(ty) { + if let TypeEnum::TObj { obj_id, fields, .. } = &*self.unifier.get_ty(ty) { // just a fast path match (fields.get(&attr), ctx == ExprContext::Store) { (Some((ty, true)), _) | (Some((ty, false)), false) => Ok(*ty), (Some((_, false)), true) => { report_error(&format!("Field `{attr}` is immutable"), value.location) } - (None, _) => { - let t = self.unifier.stringify(ty); - report_error( - &format!("`{t}::{attr}` field/method does not exist"), - value.location, - ) + (None, mutable) => { + // Check whether it is a class attribute + let defs = self.top_level.definitions.read(); + let result = { + if let TopLevelDef::Class { attributes, .. } = &*defs[obj_id.0].read() { + attributes.iter().find_map(|f| { + if f.0 == attr { + return Some(f.1); + } + None + }) + } else { + None + } + }; + match result { + Some(res) if !mutable => Ok(res), + Some(_) => report_error( + &format!("Class Attribute `{attr}` is immutable"), + value.location, + ), + None => { + let t = self.unifier.stringify(ty); + report_error( + &format!("`{t}::{attr}` field/method does not exist"), + value.location, + ) + } + } } } + } else if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(ty) { + // Access Class Attributes of classes with __init__ function using Class names e.g. Foo.ATTR1 + let result = { + self.top_level.definitions.read().iter().find_map(|def| { + if let Some(rear_guard) = def.try_read() { + if let TopLevelDef::Class { name, attributes, .. } = &*rear_guard { + if name.to_string() == self.unifier.stringify(sign.ret) { + return attributes.iter().find_map(|f| { + if f.0 == attr { + return Some(f.clone().1); + } + None + }); + } + } + } + None + }) + }; + match result { + Some(f) if ctx != ExprContext::Store => Ok(f), + Some(_) => { + report_error(&format!("Class Attribute `{attr}` is immutable"), value.location) + } + None => self.infer_general_attribute(value, attr, ctx), + } } else { - let attr_ty = self.unifier.get_dummy_var().ty; - let fields = once(( - attr.into(), - RecordField::new(attr_ty, ctx == ExprContext::Store, Some(value.location)), - )) - .collect(); - let record = self.unifier.add_record(fields); - self.constrain(value.custom.unwrap(), record, &value.location)?; - Ok(attr_ty) + self.infer_general_attribute(value, attr, ctx) } } diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index 700ce509..1e3b75f2 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -289,6 +289,7 @@ impl TestEnvironment { object_id: DefinitionId(i), type_vars: Vec::default(), fields: Vec::default(), + attributes: Vec::default(), methods: Vec::default(), ancestors: Vec::default(), resolver: None, @@ -331,6 +332,7 @@ impl TestEnvironment { object_id: DefinitionId(defs + 1), type_vars: vec![tvar.ty], fields: [("a".into(), tvar.ty, true)].into(), + attributes: Vec::default(), methods: Vec::default(), ancestors: Vec::default(), resolver: None, @@ -365,6 +367,7 @@ impl TestEnvironment { object_id: DefinitionId(defs + 2), type_vars: Vec::default(), fields: [("a".into(), int32, true), ("b".into(), fun, true)].into(), + attributes: Vec::default(), methods: Vec::default(), ancestors: Vec::default(), resolver: None, @@ -393,6 +396,7 @@ impl TestEnvironment { object_id: DefinitionId(defs + 3), type_vars: Vec::default(), fields: [("a".into(), bool, true), ("b".into(), fun, false)].into(), + attributes: Vec::default(), methods: Vec::default(), ancestors: Vec::default(), resolver: None,