diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 7dcecff8c..994c50484 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -657,7 +657,7 @@ pub fn attributes_writeback( } if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() { attributes.push(name.to_string()); - let index = ctx.get_attr_index(ty, *name); + let (index, _) = ctx.get_attr_index(ty, *name); values.push(( *field_ty, ctx.build_gep_and_load( diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index a8430e29a..2d1880586 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -86,19 +86,35 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { get_subst_key(&mut self.unifier, obj, &fun.vars, filter) } - pub fn get_attr_index(&mut self, ty: Type, attr: StrRef) -> usize { + /// Checks the field and attributes of classes + /// Returns the index of attr in class fields otherwise returns the attribute value + pub fn get_attr_index(&mut self, ty: Type, attr: StrRef) -> (usize, Option) { let obj_id = match &*self.unifier.get_ty(ty) { TypeEnum::TObj { obj_id, .. } => *obj_id, // we cannot have other types, virtual type should be handled by function calls _ => unreachable!(), }; let def = &self.top_level.definitions.read()[obj_id.0]; - let index = if let TopLevelDef::Class { fields, .. } = &*def.read() { - fields.iter().find_position(|x| x.0 == attr).unwrap().0 + let (index, value) = if let TopLevelDef::Class { fields, attributes, .. } = &*def.read() { + if let Some(field_index) = fields.iter().find_position(|x| x.0 == attr) { + (field_index.0, None) + } else { + let attribute_index = attributes.iter().find_position(|x| x.0 == attr).unwrap(); + (attribute_index.0, Some(attribute_index.1 .2.clone())) + } } else { unreachable!() }; - index + (index, value) + } + + pub fn get_attr_index_object(&mut self, ty: Type, attr: StrRef) -> usize { + match &*self.unifier.get_ty(ty) { + TypeEnum::TObj { fields, .. } => { + fields.iter().find_position(|x| *x.0 == attr).unwrap().0 + } + _ => unreachable!(), + } } pub fn gen_symbol_val( @@ -2166,11 +2182,72 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( } ExprKind::Attribute { value, attr, .. } => { // note that we would handle class methods directly in calls + + // Change Class attribute access requests to accessing constants from Class Definition + if let Some(c) = value.custom { + if let TypeEnum::TFunc(_) = &*ctx.unifier.get_ty(c) { + let defs = ctx.top_level.definitions.read(); + let result = defs.iter().find_map(|def| { + if let Some(rear_guard) = def.try_read() { + if let TopLevelDef::Class { + constructor: Some(constructor), + attributes, + .. + } = &*rear_guard + { + if *constructor == c { + return attributes.iter().find_map(|f| { + if f.0 == *attr { + // All other checks performed by this point + return Some(f.2.clone()); + } + None + }); + } + } + } + None + }); + match result { + Some(val) => { + let mut modified_expr = expr.clone(); + modified_expr.node = ExprKind::Constant { value: val, kind: None }; + + return generator.gen_expr(ctx, &modified_expr); + } + None => unreachable!("Function Type should not have attributes"), + } + } else if let TypeEnum::TObj { obj_id, fields, params } = &*ctx.unifier.get_ty(c) { + if fields.is_empty() && params.is_empty() { + let defs = ctx.top_level.definitions.read(); + let def = defs[obj_id.0].read(); + match if let TopLevelDef::Class { attributes, .. } = &*def { + attributes.iter().find_map(|f| { + if f.0 == *attr { + return Some(f.2.clone()); + } + None + }) + } else { + None + } { + Some(val) => { + let mut modified_expr = expr.clone(); + modified_expr.node = ExprKind::Constant { value: val, kind: None }; + + return generator.gen_expr(ctx, &modified_expr); + } + None => unreachable!(), + } + } + } + } + match generator.gen_expr(ctx, value)? { Some(ValueEnum::Static(v)) => v.get_field(*attr, ctx).map_or_else( || { let v = v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?; - let index = ctx.get_attr_index(value.custom.unwrap(), *attr); + let (index, _) = ctx.get_attr_index(value.custom.unwrap(), *attr); Ok(ValueEnum::Dynamic(ctx.build_gep_and_load( v.into_pointer_value(), &[zero, int32.const_int(index as u64, false)], @@ -2180,7 +2257,14 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( Ok, )?, Some(ValueEnum::Dynamic(v)) => { - let index = ctx.get_attr_index(value.custom.unwrap(), *attr); + let (index, attr_value) = ctx.get_attr_index(value.custom.unwrap(), *attr); + if let Some(val) = attr_value { + // Change to Constant Construct + let mut modified_expr = expr.clone(); + modified_expr.node = ExprKind::Constant { value: val, kind: None }; + + return generator.gen_expr(ctx, &modified_expr); + } ValueEnum::Dynamic(ctx.build_gep_and_load( v.into_pointer_value(), &[zero, int32.const_int(index as u64, false)], @@ -2363,6 +2447,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( ExprKind::Attribute { value, attr, .. } => { let Some(val) = generator.gen_expr(ctx, value)? else { return Ok(None) }; + // Handle Class Method calls let id = if let TypeEnum::TObj { obj_id, .. } = &*ctx.unifier.get_ty(value.custom.unwrap()) { diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index a670f1172..5bae9a94c 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -113,7 +113,7 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>( } }, ExprKind::Attribute { value, attr, .. } => { - let index = ctx.get_attr_index(value.custom.unwrap(), *attr); + let (index, _) = ctx.get_attr_index(value.custom.unwrap(), *attr); let val = if let Some(v) = generator.gen_expr(ctx, value)? { v.to_basic_value_enum(ctx, generator, value.custom.unwrap())? } else { diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index ff0e82c9e..5ba07df5b 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -1057,7 +1057,14 @@ impl TopLevelComposer { let (keyword_list, core_config) = core_info; let mut class_def = class_def.write(); let TopLevelDef::Class { - object_id, ancestors, fields, methods, resolver, type_vars, .. + object_id, + ancestors, + fields, + attributes, + methods, + resolver, + type_vars, + .. } = &mut *class_def else { unreachable!("here must be toplevel class def"); @@ -1073,10 +1080,14 @@ impl TopLevelComposer { class_body_ast, _class_ancestor_def, class_fields_def, + class_attributes_def, class_methods_def, class_type_vars_def, class_resolver, - ) = (*object_id, *name, bases, body, ancestors, fields, methods, type_vars, resolver); + ) = ( + *object_id, *name, bases, body, ancestors, fields, attributes, methods, type_vars, + resolver, + ); let class_resolver = class_resolver.as_ref().unwrap(); let class_resolver = class_resolver.as_ref(); @@ -1285,34 +1296,74 @@ impl TopLevelComposer { .unify(method_dummy_ty, method_type) .map_err(|e| HashSet::from([e.to_display(unifier).to_string()]))?; } - ast::StmtKind::AnnAssign { target, annotation, value: None, .. } => { + ast::StmtKind::AnnAssign { target, annotation, value, .. } => { if let ast::ExprKind::Name { id: attr, .. } = &target.node { if defined_fields.insert(attr.to_string()) { let dummy_field_type = unifier.get_dummy_var().ty; - // handle Kernel[T], KernelInvariant[T] - let (annotation, mutable) = match &annotation.node { - ast::ExprKind::Subscript { value, slice, .. } - if matches!( - &value.node, - ast::ExprKind::Name { id, .. } if id == &core_config.kernel_invariant_ann.into() - ) => - { - (slice, false) + let annotation = match value { + None => { + // handle Kernel[T], KernelInvariant[T] + let (annotation, mutable) = match &annotation.node { + ast::ExprKind::Subscript { value, slice, .. } + if matches!( + &value.node, + ast::ExprKind::Name { id, .. } if id == &core_config.kernel_invariant_ann.into() + ) => + { + (slice, false) + } + ast::ExprKind::Subscript { value, slice, .. } + if matches!( + &value.node, + ast::ExprKind::Name { id, .. } if core_config.kernel_ann.map_or(false, |c| id == &c.into()) + ) => + { + (slice, true) + } + _ if core_config.kernel_ann.is_none() => (annotation, true), + _ => continue, // ignore fields annotated otherwise + }; + class_fields_def.push((*attr, dummy_field_type, mutable)); + annotation } - ast::ExprKind::Subscript { value, slice, .. } - if matches!( - &value.node, - ast::ExprKind::Name { id, .. } if core_config.kernel_ann.map_or(false, |c| id == &c.into()) - ) => - { - (slice, true) - } - _ if core_config.kernel_ann.is_none() => (annotation, true), - _ => continue, // ignore fields annotated otherwise - }; - class_fields_def.push((*attr, dummy_field_type, mutable)); + // Supporting Class Attributes + Some(boxed_expr) => { + // Class attributes are set as immutable regardless + let (annotation, _) = match &annotation.node { + ast::ExprKind::Subscript { slice, .. } => (slice, false), + _ if core_config.kernel_ann.is_none() => (annotation, false), + _ => continue, + }; + match &**boxed_expr { + ast::Located {location: _, custom: (), node: ast::ExprKind::Constant { value: v, kind: _ }} => { + // Restricting the types allowed to be defined as class attributes + match v { + ast::Constant::Bool(_) | ast::Constant::Str(_) | ast::Constant::Int(_) | ast::Constant::Float(_) => {} + _ => { + return Err(HashSet::from([ + format!( + "unsupported statement in class definition body (at {})", + b.location + ), + ])) + } + } + class_attributes_def.push((*attr, dummy_field_type, v.clone())); + } + _ => { + return Err(HashSet::from([ + format!( + "unsupported statement in class definition body (at {})", + b.location + ), + ])) + } + } + annotation + } + }; let parsed_annotation = parse_ast_to_type_annotation_kinds( class_resolver, temp_def_list, @@ -1384,7 +1435,14 @@ impl TopLevelComposer { type_var_to_concrete_def: &mut HashMap, ) -> Result<(), HashSet> { let TopLevelDef::Class { - object_id, ancestors, fields, methods, resolver, type_vars, .. + object_id, + ancestors, + fields, + attributes, + methods, + resolver, + type_vars, + .. } = class_def else { unreachable!("here must be class def ast") @@ -1393,10 +1451,11 @@ impl TopLevelComposer { _class_id, class_ancestor_def, class_fields_def, + class_attribute_def, class_methods_def, _class_type_vars_def, _class_resolver, - ) = (*object_id, ancestors, fields, methods, type_vars, resolver); + ) = (*object_id, ancestors, fields, attributes, methods, type_vars, resolver); // since when this function is called, the ancestors of the direct parent // are supposed to be already handled, so we only need to deal with the direct parent @@ -1407,7 +1466,7 @@ impl TopLevelComposer { let base = temp_def_list.get(id.0).unwrap(); let base = base.read(); - let TopLevelDef::Class { methods, fields, .. } = &*base else { + let TopLevelDef::Class { methods, fields, attributes, .. } = &*base else { unreachable!("must be top level class def") }; @@ -1449,7 +1508,7 @@ impl TopLevelComposer { } } // use the new_child_methods to replace all the elements in `class_methods_def` - class_methods_def.drain(..); + class_methods_def.clear(); class_methods_def.extend(new_child_methods); // handle class fields @@ -1459,7 +1518,9 @@ impl TopLevelComposer { let to_be_added = (*anc_field_name, *anc_field_ty, *mutable); // find if there is a fields with the same name in the child class for (class_field_name, ..) in &*class_fields_def { - if class_field_name == anc_field_name { + if class_field_name == anc_field_name + || attributes.iter().any(|f| f.0 == *class_field_name) + { return Err(HashSet::from([format!( "field `{class_field_name}` has already declared in the ancestor classes" )])); @@ -1467,14 +1528,33 @@ impl TopLevelComposer { } new_child_fields.push(to_be_added); } + + // handle class attributes + let mut new_child_attributes: Vec<(StrRef, Type, ast::Constant)> = Vec::new(); + for (anc_attr_name, anc_attr_ty, attr_value) in attributes { + let to_be_added = (*anc_attr_name, *anc_attr_ty, attr_value.clone()); + // find if there is a attribute with the same name in the child class + for (class_attr_name, ..) in &*class_attribute_def { + if class_attr_name == anc_attr_name + || fields.iter().any(|f| f.0 == *class_attr_name) + { + return Err(HashSet::from([format!( + "attribute `{class_attr_name}` has already declared in the ancestor classes" + )])); + } + } + new_child_attributes.push(to_be_added); + } + for (class_field_name, class_field_ty, mutable) in &*class_fields_def { if !is_override.contains(class_field_name) { new_child_fields.push((*class_field_name, *class_field_ty, *mutable)); } } - class_fields_def.drain(..); + class_fields_def.clear(); class_fields_def.extend(new_child_fields); - + class_attribute_def.clear(); + class_attribute_def.extend(new_child_attributes); Ok(()) } diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index 7b5c40567..f598badfc 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -470,6 +470,7 @@ pub fn get_type_from_type_annotation_kinds( } result }; + // Class Attributes keep a copy with Class Definition and are not added to objects let mut tobj_fields = methods .iter() .map(|(name, ty, _)| { diff --git a/nac3core/src/typecheck/type_error.rs b/nac3core/src/typecheck/type_error.rs index abf61a56f..e6d7d73a3 100644 --- a/nac3core/src/typecheck/type_error.rs +++ b/nac3core/src/typecheck/type_error.rs @@ -34,6 +34,7 @@ pub enum TypeErrorKind { }, RequiresTypeAnn, PolymorphicFunctionPointer, + NoSuchAttribute(RecordKey, Type), } #[derive(Debug, Clone)] @@ -156,6 +157,10 @@ impl<'a> Display for DisplayTypeError<'a> { let t = self.unifier.stringify_with_notes(*t, &mut notes); write!(f, "`{t}::{name}` field/method does not exist") } + NoSuchAttribute(name, t) => { + let t = self.unifier.stringify_with_notes(*t, &mut notes); + write!(f, "`{t}::{name}` is not a class attribute") + } TupleIndexOutOfBounds { index, len } => { write!( f, diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 1925174a8..a3366c9b8 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -6,6 +6,7 @@ use std::{cell::RefCell, sync::Arc}; use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier, VarMap}; use super::{magic_methods::*, type_error::TypeError, typedef::CallId}; +use crate::toplevel::TopLevelDef; use crate::{ symbol_resolver::{SymbolResolver, SymbolValue}, toplevel::{ @@ -1441,6 +1442,24 @@ impl<'a> Inferencer<'a> { Ok(self.unifier.add_ty(TypeEnum::TTuple { ty })) } + /// Checks for non-class attributes + fn infer_general_attribute( + &mut self, + value: &ast::Expr>, + attr: StrRef, + ctx: ExprContext, + ) -> InferenceResult { + let attr_ty = self.unifier.get_dummy_var().ty; + let fields = once(( + attr.into(), + RecordField::new(attr_ty, ctx == ExprContext::Store, Some(value.location)), + )) + .collect(); + let record = self.unifier.add_record(fields); + self.constrain(value.custom.unwrap(), record, &value.location)?; + Ok(attr_ty) + } + fn infer_attribute( &mut self, value: &ast::Expr>, @@ -1448,31 +1467,72 @@ impl<'a> Inferencer<'a> { ctx: ExprContext, ) -> InferenceResult { let ty = value.custom.unwrap(); - if let TypeEnum::TObj { fields, .. } = &*self.unifier.get_ty(ty) { + if let TypeEnum::TObj { obj_id, fields, .. } = &*self.unifier.get_ty(ty) { // just a fast path match (fields.get(&attr), ctx == ExprContext::Store) { (Some((ty, true)), _) | (Some((ty, false)), false) => Ok(*ty), (Some((_, false)), true) => { report_error(&format!("Field `{attr}` is immutable"), value.location) } - (None, _) => { - let t = self.unifier.stringify(ty); - report_error( - &format!("`{t}::{attr}` field/method does not exist"), - value.location, - ) + (None, mutable) => { + // Check whether it is a class attribute + let defs = self.top_level.definitions.read(); + let result = { + if let TopLevelDef::Class { attributes, .. } = &*defs[obj_id.0].read() { + attributes.iter().find_map(|f| { + if f.0 == attr { + return Some(f.1); + } + None + }) + } else { + None + } + }; + match result { + Some(res) if !mutable => Ok(res), + Some(_) => report_error( + &format!("Class Attribute `{attr}` is immutable"), + value.location, + ), + None => { + let t = self.unifier.stringify(ty); + report_error( + &format!("`{t}::{attr}` field/method does not exist"), + value.location, + ) + } + } } } + } else if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(ty) { + // Access Class Attributes of classes with __init__ function using Class names e.g. Foo.ATTR1 + let result = { + self.top_level.definitions.read().iter().find_map(|def| { + if let Some(rear_guard) = def.try_read() { + if let TopLevelDef::Class { name, attributes, .. } = &*rear_guard { + if name.to_string() == self.unifier.stringify(sign.ret) { + return attributes.iter().find_map(|f| { + if f.0 == attr { + return Some(f.clone().1); + } + None + }); + } + } + } + None + }) + }; + match result { + Some(f) if ctx != ExprContext::Store => Ok(f), + Some(_) => { + report_error(&format!("Class Attribute `{attr}` is immutable"), value.location) + } + None => self.infer_general_attribute(value, attr, ctx), + } } else { - let attr_ty = self.unifier.get_dummy_var().ty; - let fields = once(( - attr.into(), - RecordField::new(attr_ty, ctx == ExprContext::Store, Some(value.location)), - )) - .collect(); - let record = self.unifier.add_record(fields); - self.constrain(value.custom.unwrap(), record, &value.location)?; - Ok(attr_ty) + self.infer_general_attribute(value, attr, ctx) } }