1
0
forked from M-Labs/nac3

core: add support for class attributes

This commit is contained in:
abdul124 2024-06-19 17:19:55 +08:00 committed by sb10q
parent 7fe2c3496c
commit 134af79fd6
7 changed files with 286 additions and 55 deletions

View File

@ -657,7 +657,7 @@ pub fn attributes_writeback(
} }
if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() { if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() {
attributes.push(name.to_string()); attributes.push(name.to_string());
let index = ctx.get_attr_index(ty, *name); let (index, _) = ctx.get_attr_index(ty, *name);
values.push(( values.push((
*field_ty, *field_ty,
ctx.build_gep_and_load( ctx.build_gep_and_load(

View File

@ -86,19 +86,35 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
get_subst_key(&mut self.unifier, obj, &fun.vars, filter) get_subst_key(&mut self.unifier, obj, &fun.vars, filter)
} }
pub fn get_attr_index(&mut self, ty: Type, attr: StrRef) -> usize { /// Checks the field and attributes of classes
/// Returns the index of attr in class fields otherwise returns the attribute value
pub fn get_attr_index(&mut self, ty: Type, attr: StrRef) -> (usize, Option<Constant>) {
let obj_id = match &*self.unifier.get_ty(ty) { let obj_id = match &*self.unifier.get_ty(ty) {
TypeEnum::TObj { obj_id, .. } => *obj_id, TypeEnum::TObj { obj_id, .. } => *obj_id,
// we cannot have other types, virtual type should be handled by function calls // we cannot have other types, virtual type should be handled by function calls
_ => unreachable!(), _ => unreachable!(),
}; };
let def = &self.top_level.definitions.read()[obj_id.0]; let def = &self.top_level.definitions.read()[obj_id.0];
let index = if let TopLevelDef::Class { fields, .. } = &*def.read() { let (index, value) = if let TopLevelDef::Class { fields, attributes, .. } = &*def.read() {
fields.iter().find_position(|x| x.0 == attr).unwrap().0 if let Some(field_index) = fields.iter().find_position(|x| x.0 == attr) {
(field_index.0, None)
} else {
let attribute_index = attributes.iter().find_position(|x| x.0 == attr).unwrap();
(attribute_index.0, Some(attribute_index.1 .2.clone()))
}
} else { } else {
unreachable!() unreachable!()
}; };
index (index, value)
}
pub fn get_attr_index_object(&mut self, ty: Type, attr: StrRef) -> usize {
match &*self.unifier.get_ty(ty) {
TypeEnum::TObj { fields, .. } => {
fields.iter().find_position(|x| *x.0 == attr).unwrap().0
}
_ => unreachable!(),
}
} }
pub fn gen_symbol_val<G: CodeGenerator + ?Sized>( pub fn gen_symbol_val<G: CodeGenerator + ?Sized>(
@ -2166,11 +2182,72 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
} }
ExprKind::Attribute { value, attr, .. } => { ExprKind::Attribute { value, attr, .. } => {
// note that we would handle class methods directly in calls // note that we would handle class methods directly in calls
// Change Class attribute access requests to accessing constants from Class Definition
if let Some(c) = value.custom {
if let TypeEnum::TFunc(_) = &*ctx.unifier.get_ty(c) {
let defs = ctx.top_level.definitions.read();
let result = defs.iter().find_map(|def| {
if let Some(rear_guard) = def.try_read() {
if let TopLevelDef::Class {
constructor: Some(constructor),
attributes,
..
} = &*rear_guard
{
if *constructor == c {
return attributes.iter().find_map(|f| {
if f.0 == *attr {
// All other checks performed by this point
return Some(f.2.clone());
}
None
});
}
}
}
None
});
match result {
Some(val) => {
let mut modified_expr = expr.clone();
modified_expr.node = ExprKind::Constant { value: val, kind: None };
return generator.gen_expr(ctx, &modified_expr);
}
None => unreachable!("Function Type should not have attributes"),
}
} else if let TypeEnum::TObj { obj_id, fields, params } = &*ctx.unifier.get_ty(c) {
if fields.is_empty() && params.is_empty() {
let defs = ctx.top_level.definitions.read();
let def = defs[obj_id.0].read();
match if let TopLevelDef::Class { attributes, .. } = &*def {
attributes.iter().find_map(|f| {
if f.0 == *attr {
return Some(f.2.clone());
}
None
})
} else {
None
} {
Some(val) => {
let mut modified_expr = expr.clone();
modified_expr.node = ExprKind::Constant { value: val, kind: None };
return generator.gen_expr(ctx, &modified_expr);
}
None => unreachable!(),
}
}
}
}
match generator.gen_expr(ctx, value)? { match generator.gen_expr(ctx, value)? {
Some(ValueEnum::Static(v)) => v.get_field(*attr, ctx).map_or_else( Some(ValueEnum::Static(v)) => v.get_field(*attr, ctx).map_or_else(
|| { || {
let v = v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?; let v = v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?;
let index = ctx.get_attr_index(value.custom.unwrap(), *attr); let (index, _) = ctx.get_attr_index(value.custom.unwrap(), *attr);
Ok(ValueEnum::Dynamic(ctx.build_gep_and_load( Ok(ValueEnum::Dynamic(ctx.build_gep_and_load(
v.into_pointer_value(), v.into_pointer_value(),
&[zero, int32.const_int(index as u64, false)], &[zero, int32.const_int(index as u64, false)],
@ -2180,7 +2257,14 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
Ok, Ok,
)?, )?,
Some(ValueEnum::Dynamic(v)) => { Some(ValueEnum::Dynamic(v)) => {
let index = ctx.get_attr_index(value.custom.unwrap(), *attr); let (index, attr_value) = ctx.get_attr_index(value.custom.unwrap(), *attr);
if let Some(val) = attr_value {
// Change to Constant Construct
let mut modified_expr = expr.clone();
modified_expr.node = ExprKind::Constant { value: val, kind: None };
return generator.gen_expr(ctx, &modified_expr);
}
ValueEnum::Dynamic(ctx.build_gep_and_load( ValueEnum::Dynamic(ctx.build_gep_and_load(
v.into_pointer_value(), v.into_pointer_value(),
&[zero, int32.const_int(index as u64, false)], &[zero, int32.const_int(index as u64, false)],
@ -2363,6 +2447,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
ExprKind::Attribute { value, attr, .. } => { ExprKind::Attribute { value, attr, .. } => {
let Some(val) = generator.gen_expr(ctx, value)? else { return Ok(None) }; let Some(val) = generator.gen_expr(ctx, value)? else { return Ok(None) };
// Handle Class Method calls
let id = if let TypeEnum::TObj { obj_id, .. } = let id = if let TypeEnum::TObj { obj_id, .. } =
&*ctx.unifier.get_ty(value.custom.unwrap()) &*ctx.unifier.get_ty(value.custom.unwrap())
{ {

View File

@ -113,7 +113,7 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
} }
}, },
ExprKind::Attribute { value, attr, .. } => { ExprKind::Attribute { value, attr, .. } => {
let index = ctx.get_attr_index(value.custom.unwrap(), *attr); let (index, _) = ctx.get_attr_index(value.custom.unwrap(), *attr);
let val = if let Some(v) = generator.gen_expr(ctx, value)? { let val = if let Some(v) = generator.gen_expr(ctx, value)? {
v.to_basic_value_enum(ctx, generator, value.custom.unwrap())? v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?
} else { } else {

View File

@ -1057,7 +1057,14 @@ impl TopLevelComposer {
let (keyword_list, core_config) = core_info; let (keyword_list, core_config) = core_info;
let mut class_def = class_def.write(); let mut class_def = class_def.write();
let TopLevelDef::Class { let TopLevelDef::Class {
object_id, ancestors, fields, methods, resolver, type_vars, .. object_id,
ancestors,
fields,
attributes,
methods,
resolver,
type_vars,
..
} = &mut *class_def } = &mut *class_def
else { else {
unreachable!("here must be toplevel class def"); unreachable!("here must be toplevel class def");
@ -1073,10 +1080,14 @@ impl TopLevelComposer {
class_body_ast, class_body_ast,
_class_ancestor_def, _class_ancestor_def,
class_fields_def, class_fields_def,
class_attributes_def,
class_methods_def, class_methods_def,
class_type_vars_def, class_type_vars_def,
class_resolver, class_resolver,
) = (*object_id, *name, bases, body, ancestors, fields, methods, type_vars, resolver); ) = (
*object_id, *name, bases, body, ancestors, fields, attributes, methods, type_vars,
resolver,
);
let class_resolver = class_resolver.as_ref().unwrap(); let class_resolver = class_resolver.as_ref().unwrap();
let class_resolver = class_resolver.as_ref(); let class_resolver = class_resolver.as_ref();
@ -1285,11 +1296,13 @@ impl TopLevelComposer {
.unify(method_dummy_ty, method_type) .unify(method_dummy_ty, method_type)
.map_err(|e| HashSet::from([e.to_display(unifier).to_string()]))?; .map_err(|e| HashSet::from([e.to_display(unifier).to_string()]))?;
} }
ast::StmtKind::AnnAssign { target, annotation, value: None, .. } => { ast::StmtKind::AnnAssign { target, annotation, value, .. } => {
if let ast::ExprKind::Name { id: attr, .. } = &target.node { if let ast::ExprKind::Name { id: attr, .. } = &target.node {
if defined_fields.insert(attr.to_string()) { if defined_fields.insert(attr.to_string()) {
let dummy_field_type = unifier.get_dummy_var().ty; let dummy_field_type = unifier.get_dummy_var().ty;
let annotation = match value {
None => {
// handle Kernel[T], KernelInvariant[T] // handle Kernel[T], KernelInvariant[T]
let (annotation, mutable) = match &annotation.node { let (annotation, mutable) = match &annotation.node {
ast::ExprKind::Subscript { value, slice, .. } ast::ExprKind::Subscript { value, slice, .. }
@ -1312,7 +1325,45 @@ impl TopLevelComposer {
_ => continue, // ignore fields annotated otherwise _ => continue, // ignore fields annotated otherwise
}; };
class_fields_def.push((*attr, dummy_field_type, mutable)); class_fields_def.push((*attr, dummy_field_type, mutable));
annotation
}
// Supporting Class Attributes
Some(boxed_expr) => {
// Class attributes are set as immutable regardless
let (annotation, _) = match &annotation.node {
ast::ExprKind::Subscript { slice, .. } => (slice, false),
_ if core_config.kernel_ann.is_none() => (annotation, false),
_ => continue,
};
match &**boxed_expr {
ast::Located {location: _, custom: (), node: ast::ExprKind::Constant { value: v, kind: _ }} => {
// Restricting the types allowed to be defined as class attributes
match v {
ast::Constant::Bool(_) | ast::Constant::Str(_) | ast::Constant::Int(_) | ast::Constant::Float(_) => {}
_ => {
return Err(HashSet::from([
format!(
"unsupported statement in class definition body (at {})",
b.location
),
]))
}
}
class_attributes_def.push((*attr, dummy_field_type, v.clone()));
}
_ => {
return Err(HashSet::from([
format!(
"unsupported statement in class definition body (at {})",
b.location
),
]))
}
}
annotation
}
};
let parsed_annotation = parse_ast_to_type_annotation_kinds( let parsed_annotation = parse_ast_to_type_annotation_kinds(
class_resolver, class_resolver,
temp_def_list, temp_def_list,
@ -1384,7 +1435,14 @@ impl TopLevelComposer {
type_var_to_concrete_def: &mut HashMap<Type, TypeAnnotation>, type_var_to_concrete_def: &mut HashMap<Type, TypeAnnotation>,
) -> Result<(), HashSet<String>> { ) -> Result<(), HashSet<String>> {
let TopLevelDef::Class { let TopLevelDef::Class {
object_id, ancestors, fields, methods, resolver, type_vars, .. object_id,
ancestors,
fields,
attributes,
methods,
resolver,
type_vars,
..
} = class_def } = class_def
else { else {
unreachable!("here must be class def ast") unreachable!("here must be class def ast")
@ -1393,10 +1451,11 @@ impl TopLevelComposer {
_class_id, _class_id,
class_ancestor_def, class_ancestor_def,
class_fields_def, class_fields_def,
class_attribute_def,
class_methods_def, class_methods_def,
_class_type_vars_def, _class_type_vars_def,
_class_resolver, _class_resolver,
) = (*object_id, ancestors, fields, methods, type_vars, resolver); ) = (*object_id, ancestors, fields, attributes, methods, type_vars, resolver);
// since when this function is called, the ancestors of the direct parent // since when this function is called, the ancestors of the direct parent
// are supposed to be already handled, so we only need to deal with the direct parent // are supposed to be already handled, so we only need to deal with the direct parent
@ -1407,7 +1466,7 @@ impl TopLevelComposer {
let base = temp_def_list.get(id.0).unwrap(); let base = temp_def_list.get(id.0).unwrap();
let base = base.read(); let base = base.read();
let TopLevelDef::Class { methods, fields, .. } = &*base else { let TopLevelDef::Class { methods, fields, attributes, .. } = &*base else {
unreachable!("must be top level class def") unreachable!("must be top level class def")
}; };
@ -1449,7 +1508,7 @@ impl TopLevelComposer {
} }
} }
// use the new_child_methods to replace all the elements in `class_methods_def` // use the new_child_methods to replace all the elements in `class_methods_def`
class_methods_def.drain(..); class_methods_def.clear();
class_methods_def.extend(new_child_methods); class_methods_def.extend(new_child_methods);
// handle class fields // handle class fields
@ -1459,7 +1518,9 @@ impl TopLevelComposer {
let to_be_added = (*anc_field_name, *anc_field_ty, *mutable); let to_be_added = (*anc_field_name, *anc_field_ty, *mutable);
// find if there is a fields with the same name in the child class // find if there is a fields with the same name in the child class
for (class_field_name, ..) in &*class_fields_def { for (class_field_name, ..) in &*class_fields_def {
if class_field_name == anc_field_name { if class_field_name == anc_field_name
|| attributes.iter().any(|f| f.0 == *class_field_name)
{
return Err(HashSet::from([format!( return Err(HashSet::from([format!(
"field `{class_field_name}` has already declared in the ancestor classes" "field `{class_field_name}` has already declared in the ancestor classes"
)])); )]));
@ -1467,14 +1528,33 @@ impl TopLevelComposer {
} }
new_child_fields.push(to_be_added); new_child_fields.push(to_be_added);
} }
// handle class attributes
let mut new_child_attributes: Vec<(StrRef, Type, ast::Constant)> = Vec::new();
for (anc_attr_name, anc_attr_ty, attr_value) in attributes {
let to_be_added = (*anc_attr_name, *anc_attr_ty, attr_value.clone());
// find if there is a attribute with the same name in the child class
for (class_attr_name, ..) in &*class_attribute_def {
if class_attr_name == anc_attr_name
|| fields.iter().any(|f| f.0 == *class_attr_name)
{
return Err(HashSet::from([format!(
"attribute `{class_attr_name}` has already declared in the ancestor classes"
)]));
}
}
new_child_attributes.push(to_be_added);
}
for (class_field_name, class_field_ty, mutable) in &*class_fields_def { for (class_field_name, class_field_ty, mutable) in &*class_fields_def {
if !is_override.contains(class_field_name) { if !is_override.contains(class_field_name) {
new_child_fields.push((*class_field_name, *class_field_ty, *mutable)); new_child_fields.push((*class_field_name, *class_field_ty, *mutable));
} }
} }
class_fields_def.drain(..); class_fields_def.clear();
class_fields_def.extend(new_child_fields); class_fields_def.extend(new_child_fields);
class_attribute_def.clear();
class_attribute_def.extend(new_child_attributes);
Ok(()) Ok(())
} }

View File

@ -470,6 +470,7 @@ pub fn get_type_from_type_annotation_kinds(
} }
result result
}; };
// Class Attributes keep a copy with Class Definition and are not added to objects
let mut tobj_fields = methods let mut tobj_fields = methods
.iter() .iter()
.map(|(name, ty, _)| { .map(|(name, ty, _)| {

View File

@ -34,6 +34,7 @@ pub enum TypeErrorKind {
}, },
RequiresTypeAnn, RequiresTypeAnn,
PolymorphicFunctionPointer, PolymorphicFunctionPointer,
NoSuchAttribute(RecordKey, Type),
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -156,6 +157,10 @@ impl<'a> Display for DisplayTypeError<'a> {
let t = self.unifier.stringify_with_notes(*t, &mut notes); let t = self.unifier.stringify_with_notes(*t, &mut notes);
write!(f, "`{t}::{name}` field/method does not exist") write!(f, "`{t}::{name}` field/method does not exist")
} }
NoSuchAttribute(name, t) => {
let t = self.unifier.stringify_with_notes(*t, &mut notes);
write!(f, "`{t}::{name}` is not a class attribute")
}
TupleIndexOutOfBounds { index, len } => { TupleIndexOutOfBounds { index, len } => {
write!( write!(
f, f,

View File

@ -6,6 +6,7 @@ use std::{cell::RefCell, sync::Arc};
use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier, VarMap}; use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier, VarMap};
use super::{magic_methods::*, type_error::TypeError, typedef::CallId}; use super::{magic_methods::*, type_error::TypeError, typedef::CallId};
use crate::toplevel::TopLevelDef;
use crate::{ use crate::{
symbol_resolver::{SymbolResolver, SymbolValue}, symbol_resolver::{SymbolResolver, SymbolValue},
toplevel::{ toplevel::{
@ -1441,29 +1442,13 @@ impl<'a> Inferencer<'a> {
Ok(self.unifier.add_ty(TypeEnum::TTuple { ty })) Ok(self.unifier.add_ty(TypeEnum::TTuple { ty }))
} }
fn infer_attribute( /// Checks for non-class attributes
fn infer_general_attribute(
&mut self, &mut self,
value: &ast::Expr<Option<Type>>, value: &ast::Expr<Option<Type>>,
attr: StrRef, attr: StrRef,
ctx: ExprContext, ctx: ExprContext,
) -> InferenceResult { ) -> InferenceResult {
let ty = value.custom.unwrap();
if let TypeEnum::TObj { fields, .. } = &*self.unifier.get_ty(ty) {
// just a fast path
match (fields.get(&attr), ctx == ExprContext::Store) {
(Some((ty, true)), _) | (Some((ty, false)), false) => Ok(*ty),
(Some((_, false)), true) => {
report_error(&format!("Field `{attr}` is immutable"), value.location)
}
(None, _) => {
let t = self.unifier.stringify(ty);
report_error(
&format!("`{t}::{attr}` field/method does not exist"),
value.location,
)
}
}
} else {
let attr_ty = self.unifier.get_dummy_var().ty; let attr_ty = self.unifier.get_dummy_var().ty;
let fields = once(( let fields = once((
attr.into(), attr.into(),
@ -1474,6 +1459,81 @@ impl<'a> Inferencer<'a> {
self.constrain(value.custom.unwrap(), record, &value.location)?; self.constrain(value.custom.unwrap(), record, &value.location)?;
Ok(attr_ty) Ok(attr_ty)
} }
fn infer_attribute(
&mut self,
value: &ast::Expr<Option<Type>>,
attr: StrRef,
ctx: ExprContext,
) -> InferenceResult {
let ty = value.custom.unwrap();
if let TypeEnum::TObj { obj_id, fields, .. } = &*self.unifier.get_ty(ty) {
// just a fast path
match (fields.get(&attr), ctx == ExprContext::Store) {
(Some((ty, true)), _) | (Some((ty, false)), false) => Ok(*ty),
(Some((_, false)), true) => {
report_error(&format!("Field `{attr}` is immutable"), value.location)
}
(None, mutable) => {
// Check whether it is a class attribute
let defs = self.top_level.definitions.read();
let result = {
if let TopLevelDef::Class { attributes, .. } = &*defs[obj_id.0].read() {
attributes.iter().find_map(|f| {
if f.0 == attr {
return Some(f.1);
}
None
})
} else {
None
}
};
match result {
Some(res) if !mutable => Ok(res),
Some(_) => report_error(
&format!("Class Attribute `{attr}` is immutable"),
value.location,
),
None => {
let t = self.unifier.stringify(ty);
report_error(
&format!("`{t}::{attr}` field/method does not exist"),
value.location,
)
}
}
}
}
} else if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(ty) {
// Access Class Attributes of classes with __init__ function using Class names e.g. Foo.ATTR1
let result = {
self.top_level.definitions.read().iter().find_map(|def| {
if let Some(rear_guard) = def.try_read() {
if let TopLevelDef::Class { name, attributes, .. } = &*rear_guard {
if name.to_string() == self.unifier.stringify(sign.ret) {
return attributes.iter().find_map(|f| {
if f.0 == attr {
return Some(f.clone().1);
}
None
});
}
}
}
None
})
};
match result {
Some(f) if ctx != ExprContext::Store => Ok(f),
Some(_) => {
report_error(&format!("Class Attribute `{attr}` is immutable"), value.location)
}
None => self.infer_general_attribute(value, attr, ctx),
}
} else {
self.infer_general_attribute(value, attr, ctx)
}
} }
fn infer_bool_ops(&mut self, values: &[ast::Expr<Option<Type>>]) -> InferenceResult { fn infer_bool_ops(&mut self, values: &[ast::Expr<Option<Type>>]) -> InferenceResult {