Compare commits
4 Commits
ae48e2042d
...
3958a9ccd8
Author | SHA1 | Date |
---|---|---|
|
3958a9ccd8 | |
|
b0b804051a | |
|
134af79fd6 | |
|
7fe2c3496c |
|
@ -0,0 +1,40 @@
|
||||||
|
from min_artiq import *
|
||||||
|
from numpy import int32
|
||||||
|
|
||||||
|
|
||||||
|
@nac3
|
||||||
|
class Demo:
|
||||||
|
attr1: KernelInvariant[int32] = 2
|
||||||
|
attr2: int32 = 4
|
||||||
|
attr3: Kernel[int32]
|
||||||
|
|
||||||
|
@kernel
|
||||||
|
def __init__(self):
|
||||||
|
self.attr3 = 8
|
||||||
|
|
||||||
|
|
||||||
|
@nac3
|
||||||
|
class NAC3Devices:
|
||||||
|
core: KernelInvariant[Core]
|
||||||
|
attr4: KernelInvariant[int32] = 16
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.core = Core()
|
||||||
|
|
||||||
|
@kernel
|
||||||
|
def run(self):
|
||||||
|
Demo.attr1 # Supported
|
||||||
|
# Demo.attr2 # Field not accessible on Kernel
|
||||||
|
# Demo.attr3 # Only attributes can be accessed in this way
|
||||||
|
# Demo.attr1 = 2 # Attributes are immutable
|
||||||
|
|
||||||
|
self.attr4 # Attributes can be accessed within class
|
||||||
|
|
||||||
|
obj = Demo()
|
||||||
|
obj.attr1 # Attributes can be accessed by class objects
|
||||||
|
|
||||||
|
NAC3Devices.attr4 # Attributes accessible for classes without __init__
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
NAC3Devices().run()
|
|
@ -657,7 +657,7 @@ pub fn attributes_writeback(
|
||||||
}
|
}
|
||||||
if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() {
|
if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() {
|
||||||
attributes.push(name.to_string());
|
attributes.push(name.to_string());
|
||||||
let index = ctx.get_attr_index(ty, *name);
|
let (index, _) = ctx.get_attr_index(ty, *name);
|
||||||
values.push((
|
values.push((
|
||||||
*field_ty,
|
*field_ty,
|
||||||
ctx.build_gep_and_load(
|
ctx.build_gep_and_load(
|
||||||
|
|
|
@ -627,12 +627,15 @@ impl InnerResolver {
|
||||||
let pyid_to_def = self.pyid_to_def.read();
|
let pyid_to_def = self.pyid_to_def.read();
|
||||||
let constructor_ty = pyid_to_def.get(&py_obj_id).and_then(|def_id| {
|
let constructor_ty = pyid_to_def.get(&py_obj_id).and_then(|def_id| {
|
||||||
defs.iter().find_map(|def| {
|
defs.iter().find_map(|def| {
|
||||||
if let TopLevelDef::Class { object_id, methods, constructor, .. } = &*def.read() {
|
if let Some(rear_guard) = def.try_read() {
|
||||||
if object_id == def_id
|
if let TopLevelDef::Class { object_id, methods, constructor, .. } = &*rear_guard
|
||||||
&& constructor.is_some()
|
|
||||||
&& methods.iter().any(|(s, _, _)| s == &"__init__".into())
|
|
||||||
{
|
{
|
||||||
return *constructor;
|
if object_id == def_id
|
||||||
|
&& constructor.is_some()
|
||||||
|
&& methods.iter().any(|(s, _, _)| s == &"__init__".into())
|
||||||
|
{
|
||||||
|
return *constructor;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
None
|
None
|
||||||
|
@ -664,7 +667,29 @@ impl InnerResolver {
|
||||||
primitives,
|
primitives,
|
||||||
)? {
|
)? {
|
||||||
Ok(s) => s,
|
Ok(s) => s,
|
||||||
Err(e) => return Ok(Err(e)),
|
Err(e) => {
|
||||||
|
// Allow access to Class Attributes of Classes without having to initialize Objects
|
||||||
|
if self.pyid_to_def.read().contains_key(&py_obj_id) {
|
||||||
|
if let Some(def_id) = self.pyid_to_def.read().get(&py_obj_id).copied() {
|
||||||
|
let def = defs[def_id.0].read();
|
||||||
|
let TopLevelDef::Class { object_id, .. } = &*def else {
|
||||||
|
// only object is supported, functions are not supported
|
||||||
|
unreachable!("function type is not supported, should not be queried")
|
||||||
|
};
|
||||||
|
|
||||||
|
let ty = TypeEnum::TObj {
|
||||||
|
obj_id: *object_id,
|
||||||
|
params: VarMap::new(),
|
||||||
|
fields: HashMap::new(),
|
||||||
|
};
|
||||||
|
(unifier.add_ty(ty), true)
|
||||||
|
} else {
|
||||||
|
return Ok(Err(e));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return Ok(Err(e));
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
match (&*unifier.get_ty(extracted_ty), inst_check) {
|
match (&*unifier.get_ty(extracted_ty), inst_check) {
|
||||||
// do the instantiation for these four types
|
// do the instantiation for these four types
|
||||||
|
|
|
@ -86,19 +86,35 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||||
get_subst_key(&mut self.unifier, obj, &fun.vars, filter)
|
get_subst_key(&mut self.unifier, obj, &fun.vars, filter)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_attr_index(&mut self, ty: Type, attr: StrRef) -> usize {
|
/// Checks the field and attributes of classes
|
||||||
|
/// Returns the index of attr in class fields otherwise returns the attribute value
|
||||||
|
pub fn get_attr_index(&mut self, ty: Type, attr: StrRef) -> (usize, Option<Constant>) {
|
||||||
let obj_id = match &*self.unifier.get_ty(ty) {
|
let obj_id = match &*self.unifier.get_ty(ty) {
|
||||||
TypeEnum::TObj { obj_id, .. } => *obj_id,
|
TypeEnum::TObj { obj_id, .. } => *obj_id,
|
||||||
// we cannot have other types, virtual type should be handled by function calls
|
// we cannot have other types, virtual type should be handled by function calls
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
};
|
};
|
||||||
let def = &self.top_level.definitions.read()[obj_id.0];
|
let def = &self.top_level.definitions.read()[obj_id.0];
|
||||||
let index = if let TopLevelDef::Class { fields, .. } = &*def.read() {
|
let (index, value) = if let TopLevelDef::Class { fields, attributes, .. } = &*def.read() {
|
||||||
fields.iter().find_position(|x| x.0 == attr).unwrap().0
|
if let Some(field_index) = fields.iter().find_position(|x| x.0 == attr) {
|
||||||
|
(field_index.0, None)
|
||||||
|
} else {
|
||||||
|
let attribute_index = attributes.iter().find_position(|x| x.0 == attr).unwrap();
|
||||||
|
(attribute_index.0, Some(attribute_index.1 .2.clone()))
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
};
|
};
|
||||||
index
|
(index, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_attr_index_object(&mut self, ty: Type, attr: StrRef) -> usize {
|
||||||
|
match &*self.unifier.get_ty(ty) {
|
||||||
|
TypeEnum::TObj { fields, .. } => {
|
||||||
|
fields.iter().find_position(|x| *x.0 == attr).unwrap().0
|
||||||
|
}
|
||||||
|
_ => unreachable!(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn gen_symbol_val<G: CodeGenerator + ?Sized>(
|
pub fn gen_symbol_val<G: CodeGenerator + ?Sized>(
|
||||||
|
@ -2166,11 +2182,72 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||||
}
|
}
|
||||||
ExprKind::Attribute { value, attr, .. } => {
|
ExprKind::Attribute { value, attr, .. } => {
|
||||||
// note that we would handle class methods directly in calls
|
// note that we would handle class methods directly in calls
|
||||||
|
|
||||||
|
// Change Class attribute access requests to accessing constants from Class Definition
|
||||||
|
if let Some(c) = value.custom {
|
||||||
|
if let TypeEnum::TFunc(_) = &*ctx.unifier.get_ty(c) {
|
||||||
|
let defs = ctx.top_level.definitions.read();
|
||||||
|
let result = defs.iter().find_map(|def| {
|
||||||
|
if let Some(rear_guard) = def.try_read() {
|
||||||
|
if let TopLevelDef::Class {
|
||||||
|
constructor: Some(constructor),
|
||||||
|
attributes,
|
||||||
|
..
|
||||||
|
} = &*rear_guard
|
||||||
|
{
|
||||||
|
if *constructor == c {
|
||||||
|
return attributes.iter().find_map(|f| {
|
||||||
|
if f.0 == *attr {
|
||||||
|
// All other checks performed by this point
|
||||||
|
return Some(f.2.clone());
|
||||||
|
}
|
||||||
|
None
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
});
|
||||||
|
match result {
|
||||||
|
Some(val) => {
|
||||||
|
let mut modified_expr = expr.clone();
|
||||||
|
modified_expr.node = ExprKind::Constant { value: val, kind: None };
|
||||||
|
|
||||||
|
return generator.gen_expr(ctx, &modified_expr);
|
||||||
|
}
|
||||||
|
None => unreachable!("Function Type should not have attributes"),
|
||||||
|
}
|
||||||
|
} else if let TypeEnum::TObj { obj_id, fields, params } = &*ctx.unifier.get_ty(c) {
|
||||||
|
if fields.is_empty() && params.is_empty() {
|
||||||
|
let defs = ctx.top_level.definitions.read();
|
||||||
|
let def = defs[obj_id.0].read();
|
||||||
|
match if let TopLevelDef::Class { attributes, .. } = &*def {
|
||||||
|
attributes.iter().find_map(|f| {
|
||||||
|
if f.0 == *attr {
|
||||||
|
return Some(f.2.clone());
|
||||||
|
}
|
||||||
|
None
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
} {
|
||||||
|
Some(val) => {
|
||||||
|
let mut modified_expr = expr.clone();
|
||||||
|
modified_expr.node = ExprKind::Constant { value: val, kind: None };
|
||||||
|
|
||||||
|
return generator.gen_expr(ctx, &modified_expr);
|
||||||
|
}
|
||||||
|
None => unreachable!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
match generator.gen_expr(ctx, value)? {
|
match generator.gen_expr(ctx, value)? {
|
||||||
Some(ValueEnum::Static(v)) => v.get_field(*attr, ctx).map_or_else(
|
Some(ValueEnum::Static(v)) => v.get_field(*attr, ctx).map_or_else(
|
||||||
|| {
|
|| {
|
||||||
let v = v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?;
|
let v = v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?;
|
||||||
let index = ctx.get_attr_index(value.custom.unwrap(), *attr);
|
let (index, _) = ctx.get_attr_index(value.custom.unwrap(), *attr);
|
||||||
Ok(ValueEnum::Dynamic(ctx.build_gep_and_load(
|
Ok(ValueEnum::Dynamic(ctx.build_gep_and_load(
|
||||||
v.into_pointer_value(),
|
v.into_pointer_value(),
|
||||||
&[zero, int32.const_int(index as u64, false)],
|
&[zero, int32.const_int(index as u64, false)],
|
||||||
|
@ -2180,7 +2257,14 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||||
Ok,
|
Ok,
|
||||||
)?,
|
)?,
|
||||||
Some(ValueEnum::Dynamic(v)) => {
|
Some(ValueEnum::Dynamic(v)) => {
|
||||||
let index = ctx.get_attr_index(value.custom.unwrap(), *attr);
|
let (index, attr_value) = ctx.get_attr_index(value.custom.unwrap(), *attr);
|
||||||
|
if let Some(val) = attr_value {
|
||||||
|
// Change to Constant Construct
|
||||||
|
let mut modified_expr = expr.clone();
|
||||||
|
modified_expr.node = ExprKind::Constant { value: val, kind: None };
|
||||||
|
|
||||||
|
return generator.gen_expr(ctx, &modified_expr);
|
||||||
|
}
|
||||||
ValueEnum::Dynamic(ctx.build_gep_and_load(
|
ValueEnum::Dynamic(ctx.build_gep_and_load(
|
||||||
v.into_pointer_value(),
|
v.into_pointer_value(),
|
||||||
&[zero, int32.const_int(index as u64, false)],
|
&[zero, int32.const_int(index as u64, false)],
|
||||||
|
@ -2363,6 +2447,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||||
ExprKind::Attribute { value, attr, .. } => {
|
ExprKind::Attribute { value, attr, .. } => {
|
||||||
let Some(val) = generator.gen_expr(ctx, value)? else { return Ok(None) };
|
let Some(val) = generator.gen_expr(ctx, value)? else { return Ok(None) };
|
||||||
|
|
||||||
|
// Handle Class Method calls
|
||||||
let id = if let TypeEnum::TObj { obj_id, .. } =
|
let id = if let TypeEnum::TObj { obj_id, .. } =
|
||||||
&*ctx.unifier.get_ty(value.custom.unwrap())
|
&*ctx.unifier.get_ty(value.custom.unwrap())
|
||||||
{
|
{
|
||||||
|
|
|
@ -113,7 +113,7 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
ExprKind::Attribute { value, attr, .. } => {
|
ExprKind::Attribute { value, attr, .. } => {
|
||||||
let index = ctx.get_attr_index(value.custom.unwrap(), *attr);
|
let (index, _) = ctx.get_attr_index(value.custom.unwrap(), *attr);
|
||||||
let val = if let Some(v) = generator.gen_expr(ctx, value)? {
|
let val = if let Some(v) = generator.gen_expr(ctx, value)? {
|
||||||
v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?
|
v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -83,6 +83,7 @@ pub fn get_exn_constructor(
|
||||||
object_id: DefinitionId(class_id),
|
object_id: DefinitionId(class_id),
|
||||||
type_vars: Vec::default(),
|
type_vars: Vec::default(),
|
||||||
fields: exception_fields,
|
fields: exception_fields,
|
||||||
|
attributes: Vec::default(),
|
||||||
methods: vec![("__init__".into(), signature, DefinitionId(cons_id))],
|
methods: vec![("__init__".into(), signature, DefinitionId(cons_id))],
|
||||||
ancestors: vec![
|
ancestors: vec![
|
||||||
TypeAnnotation::CustomClass { id: DefinitionId(class_id), params: Vec::default() },
|
TypeAnnotation::CustomClass { id: DefinitionId(class_id), params: Vec::default() },
|
||||||
|
@ -596,6 +597,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
object_id: prim.id(),
|
object_id: prim.id(),
|
||||||
type_vars: Vec::default(),
|
type_vars: Vec::default(),
|
||||||
fields: make_exception_fields(int32, int64, str),
|
fields: make_exception_fields(int32, int64, str),
|
||||||
|
attributes: Vec::default(),
|
||||||
methods: Vec::default(),
|
methods: Vec::default(),
|
||||||
ancestors: vec![],
|
ancestors: vec![],
|
||||||
constructor: None,
|
constructor: None,
|
||||||
|
@ -624,7 +626,8 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
name: prim.name().into(),
|
name: prim.name().into(),
|
||||||
object_id: prim.id(),
|
object_id: prim.id(),
|
||||||
type_vars: vec![self.option_tvar.ty],
|
type_vars: vec![self.option_tvar.ty],
|
||||||
fields: vec![],
|
fields: Vec::default(),
|
||||||
|
attributes: Vec::default(),
|
||||||
methods: vec![
|
methods: vec![
|
||||||
Self::create_method(PrimDef::OptionIsSome, self.is_some_ty.0),
|
Self::create_method(PrimDef::OptionIsSome, self.is_some_ty.0),
|
||||||
Self::create_method(PrimDef::OptionIsNone, self.is_some_ty.0),
|
Self::create_method(PrimDef::OptionIsNone, self.is_some_ty.0),
|
||||||
|
@ -738,6 +741,7 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
object_id: prim.id(),
|
object_id: prim.id(),
|
||||||
type_vars: vec![self.ndarray_dtype_tvar.ty, self.ndarray_ndims_tvar.ty],
|
type_vars: vec![self.ndarray_dtype_tvar.ty, self.ndarray_ndims_tvar.ty],
|
||||||
fields: Vec::default(),
|
fields: Vec::default(),
|
||||||
|
attributes: Vec::default(),
|
||||||
methods: vec![
|
methods: vec![
|
||||||
Self::create_method(PrimDef::NDArrayCopy, self.ndarray_copy_ty.0),
|
Self::create_method(PrimDef::NDArrayCopy, self.ndarray_copy_ty.0),
|
||||||
Self::create_method(PrimDef::NDArrayFill, self.ndarray_fill_ty.0),
|
Self::create_method(PrimDef::NDArrayFill, self.ndarray_fill_ty.0),
|
||||||
|
|
|
@ -1057,7 +1057,14 @@ impl TopLevelComposer {
|
||||||
let (keyword_list, core_config) = core_info;
|
let (keyword_list, core_config) = core_info;
|
||||||
let mut class_def = class_def.write();
|
let mut class_def = class_def.write();
|
||||||
let TopLevelDef::Class {
|
let TopLevelDef::Class {
|
||||||
object_id, ancestors, fields, methods, resolver, type_vars, ..
|
object_id,
|
||||||
|
ancestors,
|
||||||
|
fields,
|
||||||
|
attributes,
|
||||||
|
methods,
|
||||||
|
resolver,
|
||||||
|
type_vars,
|
||||||
|
..
|
||||||
} = &mut *class_def
|
} = &mut *class_def
|
||||||
else {
|
else {
|
||||||
unreachable!("here must be toplevel class def");
|
unreachable!("here must be toplevel class def");
|
||||||
|
@ -1073,10 +1080,14 @@ impl TopLevelComposer {
|
||||||
class_body_ast,
|
class_body_ast,
|
||||||
_class_ancestor_def,
|
_class_ancestor_def,
|
||||||
class_fields_def,
|
class_fields_def,
|
||||||
|
class_attributes_def,
|
||||||
class_methods_def,
|
class_methods_def,
|
||||||
class_type_vars_def,
|
class_type_vars_def,
|
||||||
class_resolver,
|
class_resolver,
|
||||||
) = (*object_id, *name, bases, body, ancestors, fields, methods, type_vars, resolver);
|
) = (
|
||||||
|
*object_id, *name, bases, body, ancestors, fields, attributes, methods, type_vars,
|
||||||
|
resolver,
|
||||||
|
);
|
||||||
|
|
||||||
let class_resolver = class_resolver.as_ref().unwrap();
|
let class_resolver = class_resolver.as_ref().unwrap();
|
||||||
let class_resolver = class_resolver.as_ref();
|
let class_resolver = class_resolver.as_ref();
|
||||||
|
@ -1285,34 +1296,74 @@ impl TopLevelComposer {
|
||||||
.unify(method_dummy_ty, method_type)
|
.unify(method_dummy_ty, method_type)
|
||||||
.map_err(|e| HashSet::from([e.to_display(unifier).to_string()]))?;
|
.map_err(|e| HashSet::from([e.to_display(unifier).to_string()]))?;
|
||||||
}
|
}
|
||||||
ast::StmtKind::AnnAssign { target, annotation, value: None, .. } => {
|
ast::StmtKind::AnnAssign { target, annotation, value, .. } => {
|
||||||
if let ast::ExprKind::Name { id: attr, .. } = &target.node {
|
if let ast::ExprKind::Name { id: attr, .. } = &target.node {
|
||||||
if defined_fields.insert(attr.to_string()) {
|
if defined_fields.insert(attr.to_string()) {
|
||||||
let dummy_field_type = unifier.get_dummy_var().ty;
|
let dummy_field_type = unifier.get_dummy_var().ty;
|
||||||
|
|
||||||
// handle Kernel[T], KernelInvariant[T]
|
let annotation = match value {
|
||||||
let (annotation, mutable) = match &annotation.node {
|
None => {
|
||||||
ast::ExprKind::Subscript { value, slice, .. }
|
// handle Kernel[T], KernelInvariant[T]
|
||||||
if matches!(
|
let (annotation, mutable) = match &annotation.node {
|
||||||
&value.node,
|
ast::ExprKind::Subscript { value, slice, .. }
|
||||||
ast::ExprKind::Name { id, .. } if id == &core_config.kernel_invariant_ann.into()
|
if matches!(
|
||||||
) =>
|
&value.node,
|
||||||
{
|
ast::ExprKind::Name { id, .. } if id == &core_config.kernel_invariant_ann.into()
|
||||||
(slice, false)
|
) =>
|
||||||
|
{
|
||||||
|
(slice, false)
|
||||||
|
}
|
||||||
|
ast::ExprKind::Subscript { value, slice, .. }
|
||||||
|
if matches!(
|
||||||
|
&value.node,
|
||||||
|
ast::ExprKind::Name { id, .. } if core_config.kernel_ann.map_or(false, |c| id == &c.into())
|
||||||
|
) =>
|
||||||
|
{
|
||||||
|
(slice, true)
|
||||||
|
}
|
||||||
|
_ if core_config.kernel_ann.is_none() => (annotation, true),
|
||||||
|
_ => continue, // ignore fields annotated otherwise
|
||||||
|
};
|
||||||
|
class_fields_def.push((*attr, dummy_field_type, mutable));
|
||||||
|
annotation
|
||||||
}
|
}
|
||||||
ast::ExprKind::Subscript { value, slice, .. }
|
// Supporting Class Attributes
|
||||||
if matches!(
|
Some(boxed_expr) => {
|
||||||
&value.node,
|
// Class attributes are set as immutable regardless
|
||||||
ast::ExprKind::Name { id, .. } if core_config.kernel_ann.map_or(false, |c| id == &c.into())
|
let (annotation, _) = match &annotation.node {
|
||||||
) =>
|
ast::ExprKind::Subscript { slice, .. } => (slice, false),
|
||||||
{
|
_ if core_config.kernel_ann.is_none() => (annotation, false),
|
||||||
(slice, true)
|
_ => continue,
|
||||||
}
|
};
|
||||||
_ if core_config.kernel_ann.is_none() => (annotation, true),
|
|
||||||
_ => continue, // ignore fields annotated otherwise
|
|
||||||
};
|
|
||||||
class_fields_def.push((*attr, dummy_field_type, mutable));
|
|
||||||
|
|
||||||
|
match &**boxed_expr {
|
||||||
|
ast::Located {location: _, custom: (), node: ast::ExprKind::Constant { value: v, kind: _ }} => {
|
||||||
|
// Restricting the types allowed to be defined as class attributes
|
||||||
|
match v {
|
||||||
|
ast::Constant::Bool(_) | ast::Constant::Str(_) | ast::Constant::Int(_) | ast::Constant::Float(_) => {}
|
||||||
|
_ => {
|
||||||
|
return Err(HashSet::from([
|
||||||
|
format!(
|
||||||
|
"unsupported statement in class definition body (at {})",
|
||||||
|
b.location
|
||||||
|
),
|
||||||
|
]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
class_attributes_def.push((*attr, dummy_field_type, v.clone()));
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
return Err(HashSet::from([
|
||||||
|
format!(
|
||||||
|
"unsupported statement in class definition body (at {})",
|
||||||
|
b.location
|
||||||
|
),
|
||||||
|
]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
annotation
|
||||||
|
}
|
||||||
|
};
|
||||||
let parsed_annotation = parse_ast_to_type_annotation_kinds(
|
let parsed_annotation = parse_ast_to_type_annotation_kinds(
|
||||||
class_resolver,
|
class_resolver,
|
||||||
temp_def_list,
|
temp_def_list,
|
||||||
|
@ -1384,7 +1435,14 @@ impl TopLevelComposer {
|
||||||
type_var_to_concrete_def: &mut HashMap<Type, TypeAnnotation>,
|
type_var_to_concrete_def: &mut HashMap<Type, TypeAnnotation>,
|
||||||
) -> Result<(), HashSet<String>> {
|
) -> Result<(), HashSet<String>> {
|
||||||
let TopLevelDef::Class {
|
let TopLevelDef::Class {
|
||||||
object_id, ancestors, fields, methods, resolver, type_vars, ..
|
object_id,
|
||||||
|
ancestors,
|
||||||
|
fields,
|
||||||
|
attributes,
|
||||||
|
methods,
|
||||||
|
resolver,
|
||||||
|
type_vars,
|
||||||
|
..
|
||||||
} = class_def
|
} = class_def
|
||||||
else {
|
else {
|
||||||
unreachable!("here must be class def ast")
|
unreachable!("here must be class def ast")
|
||||||
|
@ -1393,10 +1451,11 @@ impl TopLevelComposer {
|
||||||
_class_id,
|
_class_id,
|
||||||
class_ancestor_def,
|
class_ancestor_def,
|
||||||
class_fields_def,
|
class_fields_def,
|
||||||
|
class_attribute_def,
|
||||||
class_methods_def,
|
class_methods_def,
|
||||||
_class_type_vars_def,
|
_class_type_vars_def,
|
||||||
_class_resolver,
|
_class_resolver,
|
||||||
) = (*object_id, ancestors, fields, methods, type_vars, resolver);
|
) = (*object_id, ancestors, fields, attributes, methods, type_vars, resolver);
|
||||||
|
|
||||||
// since when this function is called, the ancestors of the direct parent
|
// since when this function is called, the ancestors of the direct parent
|
||||||
// are supposed to be already handled, so we only need to deal with the direct parent
|
// are supposed to be already handled, so we only need to deal with the direct parent
|
||||||
|
@ -1407,7 +1466,7 @@ impl TopLevelComposer {
|
||||||
|
|
||||||
let base = temp_def_list.get(id.0).unwrap();
|
let base = temp_def_list.get(id.0).unwrap();
|
||||||
let base = base.read();
|
let base = base.read();
|
||||||
let TopLevelDef::Class { methods, fields, .. } = &*base else {
|
let TopLevelDef::Class { methods, fields, attributes, .. } = &*base else {
|
||||||
unreachable!("must be top level class def")
|
unreachable!("must be top level class def")
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -1449,7 +1508,7 @@ impl TopLevelComposer {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// use the new_child_methods to replace all the elements in `class_methods_def`
|
// use the new_child_methods to replace all the elements in `class_methods_def`
|
||||||
class_methods_def.drain(..);
|
class_methods_def.clear();
|
||||||
class_methods_def.extend(new_child_methods);
|
class_methods_def.extend(new_child_methods);
|
||||||
|
|
||||||
// handle class fields
|
// handle class fields
|
||||||
|
@ -1459,7 +1518,9 @@ impl TopLevelComposer {
|
||||||
let to_be_added = (*anc_field_name, *anc_field_ty, *mutable);
|
let to_be_added = (*anc_field_name, *anc_field_ty, *mutable);
|
||||||
// find if there is a fields with the same name in the child class
|
// find if there is a fields with the same name in the child class
|
||||||
for (class_field_name, ..) in &*class_fields_def {
|
for (class_field_name, ..) in &*class_fields_def {
|
||||||
if class_field_name == anc_field_name {
|
if class_field_name == anc_field_name
|
||||||
|
|| attributes.iter().any(|f| f.0 == *class_field_name)
|
||||||
|
{
|
||||||
return Err(HashSet::from([format!(
|
return Err(HashSet::from([format!(
|
||||||
"field `{class_field_name}` has already declared in the ancestor classes"
|
"field `{class_field_name}` has already declared in the ancestor classes"
|
||||||
)]));
|
)]));
|
||||||
|
@ -1467,14 +1528,33 @@ impl TopLevelComposer {
|
||||||
}
|
}
|
||||||
new_child_fields.push(to_be_added);
|
new_child_fields.push(to_be_added);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handle class attributes
|
||||||
|
let mut new_child_attributes: Vec<(StrRef, Type, ast::Constant)> = Vec::new();
|
||||||
|
for (anc_attr_name, anc_attr_ty, attr_value) in attributes {
|
||||||
|
let to_be_added = (*anc_attr_name, *anc_attr_ty, attr_value.clone());
|
||||||
|
// find if there is a attribute with the same name in the child class
|
||||||
|
for (class_attr_name, ..) in &*class_attribute_def {
|
||||||
|
if class_attr_name == anc_attr_name
|
||||||
|
|| fields.iter().any(|f| f.0 == *class_attr_name)
|
||||||
|
{
|
||||||
|
return Err(HashSet::from([format!(
|
||||||
|
"attribute `{class_attr_name}` has already declared in the ancestor classes"
|
||||||
|
)]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
new_child_attributes.push(to_be_added);
|
||||||
|
}
|
||||||
|
|
||||||
for (class_field_name, class_field_ty, mutable) in &*class_fields_def {
|
for (class_field_name, class_field_ty, mutable) in &*class_fields_def {
|
||||||
if !is_override.contains(class_field_name) {
|
if !is_override.contains(class_field_name) {
|
||||||
new_child_fields.push((*class_field_name, *class_field_ty, *mutable));
|
new_child_fields.push((*class_field_name, *class_field_ty, *mutable));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
class_fields_def.drain(..);
|
class_fields_def.clear();
|
||||||
class_fields_def.extend(new_child_fields);
|
class_fields_def.extend(new_child_fields);
|
||||||
|
class_attribute_def.clear();
|
||||||
|
class_attribute_def.extend(new_child_attributes);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -474,6 +474,7 @@ impl TopLevelComposer {
|
||||||
object_id: obj_id,
|
object_id: obj_id,
|
||||||
type_vars: Vec::default(),
|
type_vars: Vec::default(),
|
||||||
fields: Vec::default(),
|
fields: Vec::default(),
|
||||||
|
attributes: Vec::default(),
|
||||||
methods: Vec::default(),
|
methods: Vec::default(),
|
||||||
ancestors: Vec::default(),
|
ancestors: Vec::default(),
|
||||||
constructor,
|
constructor,
|
||||||
|
|
|
@ -103,6 +103,10 @@ pub enum TopLevelDef {
|
||||||
///
|
///
|
||||||
/// Name and type is mutable.
|
/// Name and type is mutable.
|
||||||
fields: Vec<(StrRef, Type, bool)>,
|
fields: Vec<(StrRef, Type, bool)>,
|
||||||
|
/// Class Attributes.
|
||||||
|
///
|
||||||
|
/// Name, type, value.
|
||||||
|
attributes: Vec<(StrRef, Type, ast::Constant)>,
|
||||||
/// Class methods, pointing to the corresponding function definition.
|
/// Class methods, pointing to the corresponding function definition.
|
||||||
methods: Vec<(StrRef, Type, DefinitionId)>,
|
methods: Vec<(StrRef, Type, DefinitionId)>,
|
||||||
/// Ancestor classes, including itself.
|
/// Ancestor classes, including itself.
|
||||||
|
|
|
@ -470,6 +470,7 @@ pub fn get_type_from_type_annotation_kinds(
|
||||||
}
|
}
|
||||||
result
|
result
|
||||||
};
|
};
|
||||||
|
// Class Attributes keep a copy with Class Definition and are not added to objects
|
||||||
let mut tobj_fields = methods
|
let mut tobj_fields = methods
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(name, ty, _)| {
|
.map(|(name, ty, _)| {
|
||||||
|
|
|
@ -4,15 +4,22 @@ use std::fmt::Display;
|
||||||
use crate::typecheck::typedef::TypeEnum;
|
use crate::typecheck::typedef::TypeEnum;
|
||||||
|
|
||||||
use super::typedef::{RecordKey, Type, Unifier};
|
use super::typedef::{RecordKey, Type, Unifier};
|
||||||
|
use itertools::Itertools;
|
||||||
use nac3parser::ast::{Location, StrRef};
|
use nac3parser::ast::{Location, StrRef};
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum TypeErrorKind {
|
pub enum TypeErrorKind {
|
||||||
TooManyArguments {
|
GotMultipleValues {
|
||||||
expected: usize,
|
name: StrRef,
|
||||||
got: usize,
|
},
|
||||||
|
TooManyArguments {
|
||||||
|
expected_min_count: usize,
|
||||||
|
expected_max_count: usize,
|
||||||
|
got_count: usize,
|
||||||
|
},
|
||||||
|
MissingArgs {
|
||||||
|
missing_arg_names: Vec<StrRef>,
|
||||||
},
|
},
|
||||||
MissingArgs(String),
|
|
||||||
UnknownArgName(StrRef),
|
UnknownArgName(StrRef),
|
||||||
IncorrectArgType {
|
IncorrectArgType {
|
||||||
name: StrRef,
|
name: StrRef,
|
||||||
|
@ -34,6 +41,7 @@ pub enum TypeErrorKind {
|
||||||
},
|
},
|
||||||
RequiresTypeAnn,
|
RequiresTypeAnn,
|
||||||
PolymorphicFunctionPointer,
|
PolymorphicFunctionPointer,
|
||||||
|
NoSuchAttribute(RecordKey, Type),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
|
@ -77,10 +85,20 @@ impl<'a> Display for DisplayTypeError<'a> {
|
||||||
use TypeErrorKind::*;
|
use TypeErrorKind::*;
|
||||||
let mut notes = Some(HashMap::new());
|
let mut notes = Some(HashMap::new());
|
||||||
match &self.err.kind {
|
match &self.err.kind {
|
||||||
TooManyArguments { expected, got } => {
|
GotMultipleValues { name } => {
|
||||||
write!(f, "Too many arguments. Expected {expected} but got {got}")
|
write!(f, "For multiple values for parameter {name}")
|
||||||
}
|
}
|
||||||
MissingArgs(args) => {
|
TooManyArguments { expected_min_count, expected_max_count, got_count } => {
|
||||||
|
debug_assert!(expected_min_count <= expected_max_count);
|
||||||
|
if expected_min_count == expected_max_count {
|
||||||
|
let expected_count = expected_min_count; // or expected_max_count
|
||||||
|
write!(f, "Too many arguments. Expected {expected_count} but got {got_count}")
|
||||||
|
} else {
|
||||||
|
write!(f, "Too many arguments. Expected {expected_min_count} to {expected_max_count} arguments but got {got_count}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MissingArgs { missing_arg_names } => {
|
||||||
|
let args = missing_arg_names.iter().join(", ");
|
||||||
write!(f, "Missing arguments: {args}")
|
write!(f, "Missing arguments: {args}")
|
||||||
}
|
}
|
||||||
UnknownArgName(name) => {
|
UnknownArgName(name) => {
|
||||||
|
@ -89,7 +107,7 @@ impl<'a> Display for DisplayTypeError<'a> {
|
||||||
IncorrectArgType { name, expected, got } => {
|
IncorrectArgType { name, expected, got } => {
|
||||||
let expected = self.unifier.stringify_with_notes(*expected, &mut notes);
|
let expected = self.unifier.stringify_with_notes(*expected, &mut notes);
|
||||||
let got = self.unifier.stringify_with_notes(*got, &mut notes);
|
let got = self.unifier.stringify_with_notes(*got, &mut notes);
|
||||||
write!(f, "Incorrect argument type for {name}. Expected {expected}, but got {got}")
|
write!(f, "Incorrect argument type for parameter {name}. Expected {expected}, but got {got}")
|
||||||
}
|
}
|
||||||
FieldUnificationError { field, types, loc } => {
|
FieldUnificationError { field, types, loc } => {
|
||||||
let lhs = self.unifier.stringify_with_notes(types.0, &mut notes);
|
let lhs = self.unifier.stringify_with_notes(types.0, &mut notes);
|
||||||
|
@ -156,6 +174,10 @@ impl<'a> Display for DisplayTypeError<'a> {
|
||||||
let t = self.unifier.stringify_with_notes(*t, &mut notes);
|
let t = self.unifier.stringify_with_notes(*t, &mut notes);
|
||||||
write!(f, "`{t}::{name}` field/method does not exist")
|
write!(f, "`{t}::{name}` field/method does not exist")
|
||||||
}
|
}
|
||||||
|
NoSuchAttribute(name, t) => {
|
||||||
|
let t = self.unifier.stringify_with_notes(*t, &mut notes);
|
||||||
|
write!(f, "`{t}::{name}` is not a class attribute")
|
||||||
|
}
|
||||||
TupleIndexOutOfBounds { index, len } => {
|
TupleIndexOutOfBounds { index, len } => {
|
||||||
write!(
|
write!(
|
||||||
f,
|
f,
|
||||||
|
|
|
@ -6,6 +6,7 @@ use std::{cell::RefCell, sync::Arc};
|
||||||
|
|
||||||
use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier, VarMap};
|
use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier, VarMap};
|
||||||
use super::{magic_methods::*, type_error::TypeError, typedef::CallId};
|
use super::{magic_methods::*, type_error::TypeError, typedef::CallId};
|
||||||
|
use crate::toplevel::TopLevelDef;
|
||||||
use crate::{
|
use crate::{
|
||||||
symbol_resolver::{SymbolResolver, SymbolValue},
|
symbol_resolver::{SymbolResolver, SymbolValue},
|
||||||
toplevel::{
|
toplevel::{
|
||||||
|
@ -641,14 +642,7 @@ impl<'a> Inferencer<'a> {
|
||||||
})
|
})
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
let required: Vec<_> = sign
|
self.unifier.unify_call(&call, ty, sign).map_err(|e| {
|
||||||
.args
|
|
||||||
.iter()
|
|
||||||
.filter(|v| v.default_value.is_none())
|
|
||||||
.map(|v| v.name)
|
|
||||||
.rev()
|
|
||||||
.collect();
|
|
||||||
self.unifier.unify_call(&call, ty, sign, &required).map_err(|e| {
|
|
||||||
HashSet::from([e
|
HashSet::from([e
|
||||||
.at(Some(location))
|
.at(Some(location))
|
||||||
.to_display(self.unifier)
|
.to_display(self.unifier)
|
||||||
|
@ -1346,16 +1340,9 @@ impl<'a> Inferencer<'a> {
|
||||||
ret: sign.ret,
|
ret: sign.ret,
|
||||||
loc: Some(location),
|
loc: Some(location),
|
||||||
};
|
};
|
||||||
let required: Vec<_> = sign
|
self.unifier.unify_call(&call, func.custom.unwrap(), sign).map_err(|e| {
|
||||||
.args
|
HashSet::from([e.at(Some(location)).to_display(self.unifier).to_string()])
|
||||||
.iter()
|
})?;
|
||||||
.filter(|v| v.default_value.is_none())
|
|
||||||
.map(|v| v.name)
|
|
||||||
.rev()
|
|
||||||
.collect();
|
|
||||||
self.unifier.unify_call(&call, func.custom.unwrap(), sign, &required).map_err(
|
|
||||||
|e| HashSet::from([e.at(Some(location)).to_display(self.unifier).to_string()]),
|
|
||||||
)?;
|
|
||||||
return Ok(Located {
|
return Ok(Located {
|
||||||
location,
|
location,
|
||||||
custom: Some(sign.ret),
|
custom: Some(sign.ret),
|
||||||
|
@ -1441,6 +1428,24 @@ impl<'a> Inferencer<'a> {
|
||||||
Ok(self.unifier.add_ty(TypeEnum::TTuple { ty }))
|
Ok(self.unifier.add_ty(TypeEnum::TTuple { ty }))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Checks for non-class attributes
|
||||||
|
fn infer_general_attribute(
|
||||||
|
&mut self,
|
||||||
|
value: &ast::Expr<Option<Type>>,
|
||||||
|
attr: StrRef,
|
||||||
|
ctx: ExprContext,
|
||||||
|
) -> InferenceResult {
|
||||||
|
let attr_ty = self.unifier.get_dummy_var().ty;
|
||||||
|
let fields = once((
|
||||||
|
attr.into(),
|
||||||
|
RecordField::new(attr_ty, ctx == ExprContext::Store, Some(value.location)),
|
||||||
|
))
|
||||||
|
.collect();
|
||||||
|
let record = self.unifier.add_record(fields);
|
||||||
|
self.constrain(value.custom.unwrap(), record, &value.location)?;
|
||||||
|
Ok(attr_ty)
|
||||||
|
}
|
||||||
|
|
||||||
fn infer_attribute(
|
fn infer_attribute(
|
||||||
&mut self,
|
&mut self,
|
||||||
value: &ast::Expr<Option<Type>>,
|
value: &ast::Expr<Option<Type>>,
|
||||||
|
@ -1448,31 +1453,72 @@ impl<'a> Inferencer<'a> {
|
||||||
ctx: ExprContext,
|
ctx: ExprContext,
|
||||||
) -> InferenceResult {
|
) -> InferenceResult {
|
||||||
let ty = value.custom.unwrap();
|
let ty = value.custom.unwrap();
|
||||||
if let TypeEnum::TObj { fields, .. } = &*self.unifier.get_ty(ty) {
|
if let TypeEnum::TObj { obj_id, fields, .. } = &*self.unifier.get_ty(ty) {
|
||||||
// just a fast path
|
// just a fast path
|
||||||
match (fields.get(&attr), ctx == ExprContext::Store) {
|
match (fields.get(&attr), ctx == ExprContext::Store) {
|
||||||
(Some((ty, true)), _) | (Some((ty, false)), false) => Ok(*ty),
|
(Some((ty, true)), _) | (Some((ty, false)), false) => Ok(*ty),
|
||||||
(Some((_, false)), true) => {
|
(Some((_, false)), true) => {
|
||||||
report_error(&format!("Field `{attr}` is immutable"), value.location)
|
report_error(&format!("Field `{attr}` is immutable"), value.location)
|
||||||
}
|
}
|
||||||
(None, _) => {
|
(None, mutable) => {
|
||||||
let t = self.unifier.stringify(ty);
|
// Check whether it is a class attribute
|
||||||
report_error(
|
let defs = self.top_level.definitions.read();
|
||||||
&format!("`{t}::{attr}` field/method does not exist"),
|
let result = {
|
||||||
value.location,
|
if let TopLevelDef::Class { attributes, .. } = &*defs[obj_id.0].read() {
|
||||||
)
|
attributes.iter().find_map(|f| {
|
||||||
|
if f.0 == attr {
|
||||||
|
return Some(f.1);
|
||||||
|
}
|
||||||
|
None
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
};
|
||||||
|
match result {
|
||||||
|
Some(res) if !mutable => Ok(res),
|
||||||
|
Some(_) => report_error(
|
||||||
|
&format!("Class Attribute `{attr}` is immutable"),
|
||||||
|
value.location,
|
||||||
|
),
|
||||||
|
None => {
|
||||||
|
let t = self.unifier.stringify(ty);
|
||||||
|
report_error(
|
||||||
|
&format!("`{t}::{attr}` field/method does not exist"),
|
||||||
|
value.location,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(ty) {
|
||||||
|
// Access Class Attributes of classes with __init__ function using Class names e.g. Foo.ATTR1
|
||||||
|
let result = {
|
||||||
|
self.top_level.definitions.read().iter().find_map(|def| {
|
||||||
|
if let Some(rear_guard) = def.try_read() {
|
||||||
|
if let TopLevelDef::Class { name, attributes, .. } = &*rear_guard {
|
||||||
|
if name.to_string() == self.unifier.stringify(sign.ret) {
|
||||||
|
return attributes.iter().find_map(|f| {
|
||||||
|
if f.0 == attr {
|
||||||
|
return Some(f.clone().1);
|
||||||
|
}
|
||||||
|
None
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
})
|
||||||
|
};
|
||||||
|
match result {
|
||||||
|
Some(f) if ctx != ExprContext::Store => Ok(f),
|
||||||
|
Some(_) => {
|
||||||
|
report_error(&format!("Class Attribute `{attr}` is immutable"), value.location)
|
||||||
|
}
|
||||||
|
None => self.infer_general_attribute(value, attr, ctx),
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
let attr_ty = self.unifier.get_dummy_var().ty;
|
self.infer_general_attribute(value, attr, ctx)
|
||||||
let fields = once((
|
|
||||||
attr.into(),
|
|
||||||
RecordField::new(attr_ty, ctx == ExprContext::Store, Some(value.location)),
|
|
||||||
))
|
|
||||||
.collect();
|
|
||||||
let record = self.unifier.add_record(fields);
|
|
||||||
self.constrain(value.custom.unwrap(), record, &value.location)?;
|
|
||||||
Ok(attr_ty)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -289,6 +289,7 @@ impl TestEnvironment {
|
||||||
object_id: DefinitionId(i),
|
object_id: DefinitionId(i),
|
||||||
type_vars: Vec::default(),
|
type_vars: Vec::default(),
|
||||||
fields: Vec::default(),
|
fields: Vec::default(),
|
||||||
|
attributes: Vec::default(),
|
||||||
methods: Vec::default(),
|
methods: Vec::default(),
|
||||||
ancestors: Vec::default(),
|
ancestors: Vec::default(),
|
||||||
resolver: None,
|
resolver: None,
|
||||||
|
@ -331,6 +332,7 @@ impl TestEnvironment {
|
||||||
object_id: DefinitionId(defs + 1),
|
object_id: DefinitionId(defs + 1),
|
||||||
type_vars: vec![tvar.ty],
|
type_vars: vec![tvar.ty],
|
||||||
fields: [("a".into(), tvar.ty, true)].into(),
|
fields: [("a".into(), tvar.ty, true)].into(),
|
||||||
|
attributes: Vec::default(),
|
||||||
methods: Vec::default(),
|
methods: Vec::default(),
|
||||||
ancestors: Vec::default(),
|
ancestors: Vec::default(),
|
||||||
resolver: None,
|
resolver: None,
|
||||||
|
@ -365,6 +367,7 @@ impl TestEnvironment {
|
||||||
object_id: DefinitionId(defs + 2),
|
object_id: DefinitionId(defs + 2),
|
||||||
type_vars: Vec::default(),
|
type_vars: Vec::default(),
|
||||||
fields: [("a".into(), int32, true), ("b".into(), fun, true)].into(),
|
fields: [("a".into(), int32, true), ("b".into(), fun, true)].into(),
|
||||||
|
attributes: Vec::default(),
|
||||||
methods: Vec::default(),
|
methods: Vec::default(),
|
||||||
ancestors: Vec::default(),
|
ancestors: Vec::default(),
|
||||||
resolver: None,
|
resolver: None,
|
||||||
|
@ -393,6 +396,7 @@ impl TestEnvironment {
|
||||||
object_id: DefinitionId(defs + 3),
|
object_id: DefinitionId(defs + 3),
|
||||||
type_vars: Vec::default(),
|
type_vars: Vec::default(),
|
||||||
fields: [("a".into(), bool, true), ("b".into(), fun, false)].into(),
|
fields: [("a".into(), bool, true), ("b".into(), fun, false)].into(),
|
||||||
|
attributes: Vec::default(),
|
||||||
methods: Vec::default(),
|
methods: Vec::default(),
|
||||||
ancestors: Vec::default(),
|
ancestors: Vec::default(),
|
||||||
resolver: None,
|
resolver: None,
|
||||||
|
|
|
@ -89,6 +89,13 @@ pub struct FuncArg {
|
||||||
pub default_value: Option<SymbolValue>,
|
pub default_value: Option<SymbolValue>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl FuncArg {
|
||||||
|
#[must_use]
|
||||||
|
pub fn is_required(&self) -> bool {
|
||||||
|
self.default_value.is_none()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct FunSignature {
|
pub struct FunSignature {
|
||||||
pub args: Vec<FuncArg>,
|
pub args: Vec<FuncArg>,
|
||||||
|
@ -562,61 +569,153 @@ impl Unifier {
|
||||||
call: &Call,
|
call: &Call,
|
||||||
b: Type,
|
b: Type,
|
||||||
signature: &FunSignature,
|
signature: &FunSignature,
|
||||||
required: &[StrRef],
|
|
||||||
) -> Result<(), TypeError> {
|
) -> Result<(), TypeError> {
|
||||||
|
/*
|
||||||
|
NOTE: scenarios to consider:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def func1(x: int32, y: int32, z: int32 = 5): pass
|
||||||
|
|
||||||
|
# Normal scenarios
|
||||||
|
func1(23, 45) # OK, z has default
|
||||||
|
func1(23, 45, 67) # OK, z's default is overwritten
|
||||||
|
func1(x = 23, y = 45) # OK, user is using kwargs to set positional args
|
||||||
|
func1(y = 45, x = 23) # OK, kwargs order doesn't matter
|
||||||
|
|
||||||
|
# Error scenarios
|
||||||
|
func1() # ERROR: Missing arguments: x, y
|
||||||
|
func1(23) # ERROR: Missing arguments: y
|
||||||
|
func1(z = 23) # ERROR: Missing arguments: x, y
|
||||||
|
func1(x = 23) # ERROR: Missing arguments: y
|
||||||
|
func1(23, 45, x = 5) # ERROR: Got multiple values for x
|
||||||
|
func1(23, 45, x = 5, y = 6) # ERROR: Got multiple values for x (y too but Python does not report it)
|
||||||
|
func1(23, 45, 67, z = 89) # ERROR: Got multiple values for z
|
||||||
|
func1(23, 45, 67, 89) # ERROR: Function only takes from 2 to 3 positional arguments but 4 were given.
|
||||||
|
func1(23, 45, 67, w = 3) # ERROR: Got an unexpected keyword argument 'w'
|
||||||
|
|
||||||
|
# Error scenarios that do not need to be handled here.
|
||||||
|
func1(23, 45, z = 67, z = 89) # ERROR: Keyword argument repeated: z, the parser panics on this.
|
||||||
|
```
|
||||||
|
*/
|
||||||
|
|
||||||
|
struct ParamInfo<'a> {
|
||||||
|
/// Has this parameter been supplied with an argument already?
|
||||||
|
has_been_supplied: bool,
|
||||||
|
/// The corresponding [`FuncArg`] instance of this parameter (for fast table lookups)
|
||||||
|
param: &'a FuncArg,
|
||||||
|
}
|
||||||
|
|
||||||
let snapshot = self.unification_table.get_snapshot();
|
let snapshot = self.unification_table.get_snapshot();
|
||||||
if self.snapshot.is_none() {
|
if self.snapshot.is_none() {
|
||||||
self.snapshot = Some(snapshot);
|
self.snapshot = Some(snapshot);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get details about the function signature/parameters.
|
||||||
|
let num_params = signature.args.len();
|
||||||
|
|
||||||
|
// Force the type vars in `b` and `signature' to be up-to-date.
|
||||||
|
let b = self.instantiate_fun(b, signature);
|
||||||
|
let TypeEnum::TFunc(signature) = &*self.get_ty(b) else { unreachable!() };
|
||||||
|
|
||||||
|
// Get details about the input arguments
|
||||||
let Call { posargs, kwargs, ret, fun, loc } = call;
|
let Call { posargs, kwargs, ret, fun, loc } = call;
|
||||||
let instantiated = self.instantiate_fun(b, signature);
|
let num_args = posargs.len() + kwargs.len();
|
||||||
let r = self.get_ty(instantiated);
|
|
||||||
let r = r.as_ref();
|
// Now we check the arguments against the parameters
|
||||||
let TypeEnum::TFunc(signature) = r else { unreachable!() };
|
|
||||||
// we check to make sure that all required arguments (those without default
|
// Helper lambdas
|
||||||
// arguments) are provided, and do not provide the same argument twice.
|
let mut type_check_arg = |param_name, expected_arg_ty, arg_ty| {
|
||||||
let mut required = required.to_vec();
|
let ok = self.unify_impl(expected_arg_ty, arg_ty, false).is_ok();
|
||||||
let mut all_names: Vec<_> = signature.args.iter().map(|v| (v.name, v.ty)).rev().collect();
|
if ok {
|
||||||
for (i, t) in posargs.iter().enumerate() {
|
Ok(())
|
||||||
if signature.args.len() <= i {
|
} else {
|
||||||
|
// Typecheck failed, throw an error.
|
||||||
self.restore_snapshot();
|
self.restore_snapshot();
|
||||||
return Err(TypeError::new(
|
Err(TypeError::new(
|
||||||
TypeErrorKind::TooManyArguments {
|
TypeErrorKind::IncorrectArgType {
|
||||||
expected: signature.args.len(),
|
name: param_name,
|
||||||
got: posargs.len() + kwargs.len(),
|
expected: expected_arg_ty,
|
||||||
|
got: arg_ty,
|
||||||
},
|
},
|
||||||
*loc,
|
*loc,
|
||||||
));
|
))
|
||||||
}
|
}
|
||||||
required.pop();
|
};
|
||||||
let (name, expected) = all_names.pop().unwrap();
|
|
||||||
self.unify_impl(expected, *t, false).map_err(|_| {
|
// Check for "too many arguments"
|
||||||
self.restore_snapshot();
|
if num_params < posargs.len() {
|
||||||
TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc)
|
let expected_min_count =
|
||||||
})?;
|
signature.args.iter().filter(|param| param.is_required()).count();
|
||||||
}
|
let expected_max_count = num_params;
|
||||||
for (k, t) in kwargs {
|
|
||||||
if let Some(i) = required.iter().position(|v| v == k) {
|
|
||||||
required.remove(i);
|
|
||||||
}
|
|
||||||
let i = all_names.iter().position(|v| &v.0 == k).ok_or_else(|| {
|
|
||||||
self.restore_snapshot();
|
|
||||||
TypeError::new(TypeErrorKind::UnknownArgName(*k), *loc)
|
|
||||||
})?;
|
|
||||||
let (name, expected) = all_names.remove(i);
|
|
||||||
self.unify_impl(expected, *t, false).map_err(|_| {
|
|
||||||
self.restore_snapshot();
|
|
||||||
TypeError::new(TypeErrorKind::IncorrectArgType { name, expected, got: *t }, *loc)
|
|
||||||
})?;
|
|
||||||
}
|
|
||||||
if !required.is_empty() {
|
|
||||||
self.restore_snapshot();
|
self.restore_snapshot();
|
||||||
return Err(TypeError::new(
|
return Err(TypeError::new(
|
||||||
TypeErrorKind::MissingArgs(required.iter().join(", ")),
|
TypeErrorKind::TooManyArguments {
|
||||||
|
expected_min_count,
|
||||||
|
expected_max_count,
|
||||||
|
got_count: num_args,
|
||||||
|
},
|
||||||
*loc,
|
*loc,
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NOTE: order of `param_info_by_name` is leveraged, so use an IndexMap
|
||||||
|
let mut param_info_by_name: IndexMap<StrRef, ParamInfo> = signature
|
||||||
|
.args
|
||||||
|
.iter()
|
||||||
|
.map(|arg| (arg.name, ParamInfo { has_been_supplied: false, param: arg }))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Now consume all positional arguments and typecheck them.
|
||||||
|
for (&arg_ty, param) in zip(posargs, signature.args.iter()) {
|
||||||
|
// We will also use this opportunity to mark the corresponding `param_info` as having been supplied.
|
||||||
|
let param_info = param_info_by_name.get_mut(¶m.name).unwrap();
|
||||||
|
param_info.has_been_supplied = true;
|
||||||
|
|
||||||
|
// Typecheck
|
||||||
|
type_check_arg(param.name, param.ty, arg_ty)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now consume all keyword arguments and typecheck them.
|
||||||
|
for (¶m_name, &arg_ty) in kwargs {
|
||||||
|
// We will also use this opportunity to check if this keyword argument is "legal".
|
||||||
|
|
||||||
|
let Some(param_info) = param_info_by_name.get_mut(¶m_name) else {
|
||||||
|
self.restore_snapshot();
|
||||||
|
return Err(TypeError::new(TypeErrorKind::UnknownArgName(param_name), *loc));
|
||||||
|
};
|
||||||
|
|
||||||
|
if param_info.has_been_supplied {
|
||||||
|
// NOTE: Duplicate keyword argument (i.e., `hello(1, 2, 3, arg = 4, arg = 5)`)
|
||||||
|
// is IMPOSSIBLE as the parser would have already failed.
|
||||||
|
// We only have to care about "got multiple values for XYZ"
|
||||||
|
|
||||||
|
self.restore_snapshot();
|
||||||
|
return Err(TypeError::new(
|
||||||
|
TypeErrorKind::GotMultipleValues { name: param_name },
|
||||||
|
*loc,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
param_info.has_been_supplied = true;
|
||||||
|
|
||||||
|
// Typecheck
|
||||||
|
type_check_arg(param_name, param_info.param.ty, arg_ty)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// After checking posargs and kwargs, check if there are any
|
||||||
|
// unsupported required parameters, and throw an error if they exist.
|
||||||
|
let missing_arg_names = param_info_by_name
|
||||||
|
.values()
|
||||||
|
.filter(|param_info| param_info.param.is_required() && !param_info.has_been_supplied)
|
||||||
|
.map(|param_info| param_info.param.name)
|
||||||
|
.collect_vec();
|
||||||
|
if !missing_arg_names.is_empty() {
|
||||||
|
self.restore_snapshot();
|
||||||
|
return Err(TypeError::new(TypeErrorKind::MissingArgs { missing_arg_names }, *loc));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finally, check the Call's return type
|
||||||
self.unify_impl(*ret, signature.ret, false).map_err(|mut err| {
|
self.unify_impl(*ret, signature.ret, false).map_err(|mut err| {
|
||||||
self.restore_snapshot();
|
self.restore_snapshot();
|
||||||
if err.loc.is_none() {
|
if err.loc.is_none() {
|
||||||
|
@ -624,7 +723,8 @@ impl Unifier {
|
||||||
}
|
}
|
||||||
err
|
err
|
||||||
})?;
|
})?;
|
||||||
*fun.borrow_mut() = Some(instantiated);
|
|
||||||
|
*fun.borrow_mut() = Some(b);
|
||||||
|
|
||||||
self.discard_snapshot(snapshot);
|
self.discard_snapshot(snapshot);
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -990,17 +1090,10 @@ impl Unifier {
|
||||||
self.unification_table.set_value(b, Rc::new(TCall(calls)));
|
self.unification_table.set_value(b, Rc::new(TCall(calls)));
|
||||||
}
|
}
|
||||||
(TCall(calls), TFunc(signature)) => {
|
(TCall(calls), TFunc(signature)) => {
|
||||||
let required: Vec<StrRef> = signature
|
|
||||||
.args
|
|
||||||
.iter()
|
|
||||||
.filter(|v| v.default_value.is_none())
|
|
||||||
.map(|v| v.name)
|
|
||||||
.rev()
|
|
||||||
.collect();
|
|
||||||
// we unify every calls to the function signature.
|
// we unify every calls to the function signature.
|
||||||
for c in calls {
|
for c in calls {
|
||||||
let call = self.calls[c.0].clone();
|
let call = self.calls[c.0].clone();
|
||||||
self.unify_call(&call, b, signature, &required)?;
|
self.unify_call(&call, b, signature)?;
|
||||||
}
|
}
|
||||||
self.set_a_to_b(a, b);
|
self.set_a_to_b(a, b);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue