From 31a8b17a84cf52c88c83fcf44486f45f8a363b8d Mon Sep 17 00:00:00 2001 From: abdul124 Date: Fri, 12 Jul 2024 09:58:33 +0800 Subject: [PATCH] WIP --- flake.nix | 4 +- nac3artiq/demo/demo.py | 42 +- nac3core/src/toplevel/composer.rs | 1416 +++++++++++++++++++++++------ 3 files changed, 1167 insertions(+), 295 deletions(-) diff --git a/flake.nix b/flake.nix index 4febca24..3381c9b7 100644 --- a/flake.nix +++ b/flake.nix @@ -160,8 +160,10 @@ cargo-insta clippy pre-commit - rustfmt + rust-analyzer ]; + # https://nixos.wiki/wiki/Rust#Shell.nix_example + RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}"; }; devShells.x86_64-linux.msys2 = pkgs.mkShell { name = "nac3-dev-shell-msys2"; diff --git a/nac3artiq/demo/demo.py b/nac3artiq/demo/demo.py index aa135757..2bf538e8 100644 --- a/nac3artiq/demo/demo.py +++ b/nac3artiq/demo/demo.py @@ -1,26 +1,46 @@ +from numpy import int64 from min_artiq import * +@nac3 +class A: + pass +@nac3 +class B(A): + pass +@nac3 +class C(B): + pass + +@kernel +def test(): + pass @nac3 class Demo: - core: KernelInvariant[Core] - led0: KernelInvariant[TTLOut] - led1: KernelInvariant[TTLOut] + core: Kernel[Core] + led0: Kernel[TTLOut] + optt: Kernel[float] + t_mu: Kernel[int64] def __init__(self): self.core = Core() self.led0 = TTLOut(self.core, 18) - self.led1 = TTLOut(self.core, 19) + self.optt = 0.0 + self.t_mu = int64(0) + + @kernel + def set_time_kernel(self, t: float): + self.optt = t + self.t_mu = self.core.seconds_to_mu(self.optt) @kernel def run(self): - self.core.reset() - while True: - with parallel: - self.led0.pulse(100.*ms) - self.led1.pulse(100.*ms) - self.core.delay(100.*ms) + for t in [20.*us, 15.*us, 10.*us, 5.*us, 2.*us, 1.*us]: + self.set_time_kernel(t) + self.core.reset() + for _ in range(10000): + self.led0.pulse(self.optt) if __name__ == "__main__": - Demo().run() + Demo().run() \ No newline at end of file diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 5ba07df5..c429ef3d 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -368,8 +368,9 @@ impl TopLevelComposer { } pub fn start_analysis(&mut self, inference: bool) -> Result<(), HashSet> { - self.analyze_top_level_class_type_var()?; - self.analyze_top_level_class_bases()?; + // self.analyze_top_level_class_type_var()?; + // self.analyze_top_level_class_bases()?; + self.alternative_analyze()?; self.analyze_top_level_class_fields_methods()?; self.analyze_top_level_function()?; if inference { @@ -378,302 +379,201 @@ impl TopLevelComposer { Ok(()) } - /// step 1, analyze the type vars associated with top level class - fn analyze_top_level_class_type_var(&mut self) -> Result<(), HashSet> { + fn analyze_type_vars( + temp_def_list: &[Arc>], + unifier: &mut Unifier, + primitives_store: &PrimitiveStore, + class_def: &Arc>, + class_ast: &Option, + direct_parent: &mut HashMap>, + ) -> Result<(), HashSet> { + let mut class_def = class_def.write(); + let (class_bases, class_def_id, class_type_vars, class_resolver) = { + if let TopLevelDef::Class { object_id, type_vars, resolver, .. } = &mut *class_def { + let Some(ast::Located { node: ast::StmtKind::ClassDef { bases, .. }, .. }) = + class_ast + else { + unreachable!() + }; + + (bases, *object_id, type_vars, resolver.as_ref().unwrap().as_ref()) + } else { + return Ok(()); + } + }; + + let mut is_generic = false; + direct_parent.insert(class_def_id, None); // Initialization + for b in class_bases { + if let ast::ExprKind::Subscript { value, slice, .. } = &b.node { + if matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"Generic".into()) { + if is_generic { + return Err(HashSet::from([format!("only single Generic[...] is allowed (at {})", b.location)])); + } + is_generic = true; + let type_var_list: Vec<&ast::Expr<()>>; + // if `class A(Generic[T, V, G])` + if let ast::ExprKind::Tuple { elts, .. } = &slice.node { + type_var_list = elts.iter().collect_vec(); + // `class A(Generic[T])` + } else { + type_var_list = vec![slice.as_ref()]; + } + + // parse the type vars + let type_vars = type_var_list + .into_iter() + .map(|e| { + class_resolver.parse_type_annotation( + &temp_def_list, + unifier, + primitives_store, + e, + ) + }) + .collect::, _>>()?; + + // check if all are unique type vars + let all_unique_type_var = { + let mut occurred_type_var_id: HashSet = HashSet::new(); + type_vars.iter().all(|x| { + let ty = unifier.get_ty(*x); + if let TypeEnum::TVar { id, .. } = ty.as_ref() { + occurred_type_var_id.insert(*id) + } else { + false + } + }) + }; + + if !all_unique_type_var { + return Err(HashSet::from([format!("duplicate type variable occurs (at {})", slice.location)])); + } + + // add to TopLevelDef + class_type_vars.extend(type_vars); + continue; + } + } + if direct_parent.get(&class_def_id).unwrap().is_none(){ + let base_ty = parse_ast_to_type_annotation_kinds( + class_resolver, + &temp_def_list, + unifier, + primitives_store, + b, + vec![(class_def_id, class_type_vars.clone())] + .into_iter() + .collect::>(), + )?; + if let TypeAnnotation::CustomClass { id, .. } = &base_ty { + direct_parent.insert(class_def_id, Some((*id, base_ty))); + } else { + return Err(HashSet::from([format!("class base declaration can only be custom class (at {})", b.location)])); + } + } else { + return Err(HashSet::from([format!("a class definition can only have at most one base class declaration and one generic declaration (at {})", b.location)])); + } + } + Ok(()) + } + + fn analyze_ancestors( + class_def: &Arc>, + class_ast: &Option, + direct_parent: HashMap>, + ) -> Result<(), HashSet> { + let mut class_def = class_def.write(); + let (class_def_id, class_type_vars, class_ancestors) = { + if let TopLevelDef::Class { object_id, type_vars, ancestors, .. } = &mut *class_def { + (*object_id, type_vars, ancestors) + } else { + return Ok(()) + } + }; + let mut ancestor = direct_parent[&class_def_id].clone(); + while !ancestor.is_none() { + let Some((ancestor_id, ann)) = ancestor else {unreachable!()}; + // Special Check for classes that inherit from Exceptions + if ancestor_id.0 == 7 { + let ast::StmtKind::ClassDef { body, .. } = &class_ast.as_ref().unwrap().node else { + unreachable!() + }; + + for stmt in body { + if matches!(stmt.node, ast::StmtKind::FunctionDef { .. } | ast::StmtKind::AnnAssign { .. }) { + return Err(HashSet::from(["Classes inherited from exception should have no custom fields/methods".into(),])); + } + } + } + // A(B) B(C) D(C), => A->B, B->C, D->C => + class_ancestors.push(ann); + + // Check for circular reference + if ancestor_id == class_def_id { + return Err(HashSet::from(["cyclic inheritance detected".into()])); + } + ancestor = direct_parent[&ancestor_id].clone(); + } + + // Every Class Inherits from itself + class_ancestors.insert(0, make_self_type_annotation(class_type_vars.as_slice(), class_def_id)); + Ok(()) + + } + + fn analyze_class_methods_fields ( + temp_def_list: &[Arc>], + unifier: &mut Unifier, + primitives_store: &PrimitiveStore, + class_def: &Arc>, + class_ast: &Option, + type_var_to_concrete_def: &mut HashMap, + core_info: (&HashSet, &ComposerConfig), + ) { + + } + fn alternative_analyze(&mut self) -> Result<(), HashSet> { let def_list = &self.definition_ast_list; let temp_def_list = self.extract_def_list(); let unifier = self.unifier.borrow_mut(); let primitives_store = &self.primitives_ty; - - let mut analyze = |class_def: &Arc>, class_ast: &Option| { - // only deal with class def here - let mut class_def = class_def.write(); - let (class_bases_ast, class_def_type_vars, class_resolver) = { - if let TopLevelDef::Class { type_vars, resolver, .. } = &mut *class_def { - let Some(ast::Located { node: ast::StmtKind::ClassDef { bases, .. }, .. }) = - class_ast - else { - unreachable!() - }; - - (bases, type_vars, resolver) - } else { - return Ok(()); - } - }; - let class_resolver = class_resolver.as_ref().unwrap(); - let class_resolver = &**class_resolver; - - let mut is_generic = false; - for b in class_bases_ast { - match &b.node { - // analyze typevars bounded to the class, - // only support things like `class A(Generic[T, V])`, - // things like `class A(Generic[T, V, ImportedModule.T])` is not supported - // i.e. only simple names are allowed in the subscript - // should update the TopLevelDef::Class.typevars and the TypeEnum::TObj.params - ast::ExprKind::Subscript { value, slice, .. } - if { - matches!( - &value.node, - ast::ExprKind::Name { id, .. } if id == &"Generic".into() - ) - } => - { - if is_generic { - return Err(HashSet::from([format!( - "only single Generic[...] is allowed (at {})", - b.location - )])); - } - is_generic = true; - - let type_var_list: Vec<&ast::Expr<()>>; - // if `class A(Generic[T, V, G])` - if let ast::ExprKind::Tuple { elts, .. } = &slice.node { - type_var_list = elts.iter().collect_vec(); - // `class A(Generic[T])` - } else { - type_var_list = vec![&**slice]; - } - - // parse the type vars - let type_vars = type_var_list - .into_iter() - .map(|e| { - class_resolver.parse_type_annotation( - &temp_def_list, - unifier, - primitives_store, - e, - ) - }) - .collect::, _>>()?; - - // check if all are unique type vars - let all_unique_type_var = { - let mut occurred_type_var_id: HashSet = HashSet::new(); - type_vars.iter().all(|x| { - let ty = unifier.get_ty(*x); - if let TypeEnum::TVar { id, .. } = ty.as_ref() { - occurred_type_var_id.insert(*id) - } else { - false - } - }) - }; - if !all_unique_type_var { - return Err(HashSet::from([format!( - "duplicate type variable occurs (at {})", - slice.location - )])); - } - - // add to TopLevelDef - class_def_type_vars.extend(type_vars); - } - - // if others, do nothing in this function - _ => continue, - } - } - Ok(()) - }; + let mut direct_parent: HashMap> = HashMap::default(); + let mut class_def: Vec<(&Arc>, &Option)> = Vec::default(); let mut errors = HashSet::new(); - for (class_def, class_ast) in def_list.iter().skip(self.builtin_num) { - if class_ast.is_none() { - continue; - } - if let Err(e) = analyze(class_def, class_ast) { - errors.extend(e); - } - } - if !errors.is_empty() { - return Err(errors); - } - Ok(()) - } - /// step 2, base classes. - /// now that the type vars of all classes are done, handle base classes and - /// put Self class into the ancestors list. We only allow single inheritance - fn analyze_top_level_class_bases(&mut self) -> Result<(), HashSet> { - if self.unifier.top_level.is_none() { - let ctx = Arc::new(self.make_top_level_context()); - self.unifier.top_level = Some(ctx); - } - - let temp_def_list = self.extract_def_list(); - let unifier = self.unifier.borrow_mut(); - let primitive_types = self.primitives_ty; - - let mut get_direct_parents = - |class_def: &Arc>, class_ast: &Option| { - let mut class_def = class_def.write(); - let (class_def_id, class_bases, class_ancestors, class_resolver, class_type_vars) = { - if let TopLevelDef::Class { - ancestors, resolver, object_id, type_vars, .. - } = &mut *class_def - { - let Some(ast::Located { - node: ast::StmtKind::ClassDef { bases, .. }, .. - }) = class_ast - else { - unreachable!() - }; - - (object_id, bases, ancestors, resolver, type_vars) - } else { - return Ok(()); - } - }; - let class_resolver = class_resolver.as_ref().unwrap(); - let class_resolver = &**class_resolver; - - let mut has_base = false; - for b in class_bases { - // type vars have already been handled, so skip on `Generic[...]` - if matches!( - &b.node, - ast::ExprKind::Subscript { value, .. } - if matches!( - &value.node, - ast::ExprKind::Name { id, .. } if id == &"Generic".into() - ) - ) { - continue; - } - - if has_base { - return Err(HashSet::from([format!( - "a class definition can only have at most one base class \ - declaration and one generic declaration (at {})", - b.location - )])); - } - has_base = true; - - // the function parse_ast_to make sure that no type var occurred in - // bast_ty if it is a CustomClassKind - let base_ty = parse_ast_to_type_annotation_kinds( - class_resolver, - &temp_def_list, - unifier, - &primitive_types, - b, - vec![(*class_def_id, class_type_vars.clone())] - .into_iter() - .collect::>(), - )?; - - if let TypeAnnotation::CustomClass { .. } = &base_ty { - class_ancestors.push(base_ty); - } else { - return Err(HashSet::from([format!( - "class base declaration can only be custom class (at {})", - b.location, - )])); - } - } - Ok(()) - }; - - // first, only push direct parent into the list - let mut errors = HashSet::new(); - for (class_def, class_ast) in self.definition_ast_list.iter_mut().skip(self.builtin_num) { - if class_ast.is_none() { - continue; - } - if let Err(e) = get_direct_parents(class_def, class_ast) { - errors.extend(e); - } - } - if !errors.is_empty() { - return Err(errors); - } - - // second, get all ancestors - let mut ancestors_store: HashMap> = HashMap::default(); - let mut get_all_ancestors = - |class_def: &Arc>| -> Result<(), HashSet> { - let class_def = class_def.read(); - let (class_ancestors, class_id) = { - if let TopLevelDef::Class { ancestors, object_id, .. } = &*class_def { - (ancestors, *object_id) - } else { - return Ok(()); - } - }; - ancestors_store.insert( - class_id, - // if class has direct parents, get all ancestors of its parents. Else just empty - if class_ancestors.is_empty() { - vec![] - } else { - Self::get_all_ancestors_helper( - &class_ancestors[0], - temp_def_list.as_slice(), - )? - }, - ); - Ok(()) - }; - for (class_def, ast) in self.definition_ast_list.iter().skip(self.builtin_num) { + // Separate Class def + for (def, ast) in def_list.iter().skip(self.builtin_num) { + // Allow empty classes if ast.is_none() { continue; } - if let Err(e) = get_all_ancestors(class_def) { - errors.extend(e); + // Check if its a class definition + if let TopLevelDef::Class { .. } = &*def.read() { + class_def.push((def, ast)); + } else { + // Later } } + + // check 1 + for def in class_def.iter() { + if let Err(e) = Self::analyze_type_vars(&temp_def_list, unifier, primitives_store, def.0, def.1, &mut direct_parent){ + errors.extend(e); + }; + } if !errors.is_empty() { return Err(errors); } - - // insert the ancestors to the def list - for (class_def, class_ast) in self.definition_ast_list.iter_mut().skip(self.builtin_num) { - if class_ast.is_none() { - continue; - } - let mut class_def = class_def.write(); - let (class_ancestors, class_id, class_type_vars) = { - if let TopLevelDef::Class { ancestors, object_id, type_vars, .. } = &mut *class_def - { - (ancestors, *object_id, type_vars) - } else { - continue; - } + + // Check 2 + for def in class_def.iter() { + if let Err(e) = Self::analyze_ancestors(def.0, def.1, direct_parent.clone()){ + errors.extend(e); }; - - let ans = ancestors_store.get_mut(&class_id).unwrap(); - class_ancestors.append(ans); - - // insert self type annotation to the front of the vector to maintain the order - class_ancestors - .insert(0, make_self_type_annotation(class_type_vars.as_slice(), class_id)); - - // special case classes that inherit from Exception - if class_ancestors - .iter() - .any(|ann| matches!(ann, TypeAnnotation::CustomClass { id, .. } if id.0 == 7)) - { - // if inherited from Exception, the body should be a pass - let ast::StmtKind::ClassDef { body, .. } = &class_ast.as_ref().unwrap().node else { - unreachable!() - }; - - for stmt in body { - if matches!( - stmt.node, - ast::StmtKind::FunctionDef { .. } | ast::StmtKind::AnnAssign { .. } - ) { - return Err(HashSet::from([ - "Classes inherited from exception should have no custom fields/methods" - .into(), - ])); - } - } - } + } + if !errors.is_empty() { + return Err(errors); } // deal with ancestor of Exception object @@ -686,8 +586,646 @@ impl TopLevelComposer { assert_eq!(*name, "Exception".into()); ancestors.push(make_self_type_annotation(&[], *object_id)); + + + // Check 3 analyze_top_level_class_fields_methods + for def in class_def.iter() { + let TopLevelDef::Class { object_id, type_vars: class_type_vars, fields: class_fields_def, attributes: class_attributes_def, methods: class_methods_def, ancestors: class_ancestors, resolver, .. } = &*def.0.write() else { + unreachable!() + }; + let ast::StmtKind::ClassDef { name, bases: class_bases_ast, body: class_body_ast, .. } = &def.1.as_ref().unwrap().node else {unreachable!()}; + let class_name = *name; + let class_id = *object_id; + let class_resolver = resolver.as_ref().unwrap().as_ref(); + + let mut defined_fields: HashSet = HashSet::new(); + for b in class_body_ast { + match &b.node { + ast::StmtKind::FunctionDef { name, args, returns, .. } => { + // Get argument type + let arg_types = { + let (arg, arg_default) = (&args.args, &args.defaults); + let arg_default = std::iter::repeat(None).take(arg.len() - arg_default.len()).chain(arg_default.iter().map(|x| Some(x))); + + // All Class functions must have self and unique parameter names, parameter names should not be same as keywords, and should be type annotated + let mut defined_parameter_name: HashSet = HashSet::new(); + let mut has_self = false; + for x in arg.iter() { + if x.node.arg == "self".into() { + has_self = true; + } else if self.keyword_list.contains(&x.node.arg) { + return Err(HashSet::from([format!("function parameter names should not be the same as the keywords (at {})", x.location)])) + } else if !defined_parameter_name.insert(x.node.arg.into()) { + return Err(HashSet::from([format!("function parameter names should be unique (at {})", x.location)])) + } else if x.node.annotation.is_none() { + return Err(HashSet::from([format!("type annotation needed for `{}` at {}",x.node.arg, x.location)])) + } + } + if !has_self { + return Err(HashSet::from([format!("class method must have a `self` parameter (at {})", b.location),])) + } + + // Get type annotation of parameters + for (x, d) in arg.iter().zip(arg_default) { + // Annotation and default with self ignored? + if x.node.arg == "self".into() { + if let Some(k) = d { + + } + continue; + } + let type_ann = parse_ast_to_type_annotation_kinds( + class_resolver, + &temp_def_list, + unifier, + primitives_store, + x.node.annotation.clone().unwrap().as_ref(), + vec![(class_id, class_type_vars.clone())] + .into_iter() + .collect::>(), + )?; + + // find type vars within this method parameter type annotation + let type_vars_within = get_type_var_contained_in_type_annotation(&type_ann); + + } + }; + + // let (method_dummy_ty, method_id) = + // Self::get_class_method_def_info(class_methods_def, *name)?; + + // let mut method_var_map = VarMap::new(); + + // let arg_types: Vec = { + // // check method parameters cannot have same name + // let mut defined_parameter_name: HashSet<_> = HashSet::new(); + // let zelf: StrRef = "self".into(); + + // let mut result = Vec::new(); + + // let arg_with_default: Vec<( + // &ast::Located>, + // Option<&ast::Expr>, + // )> = args + // .args + // .iter() + // .rev() + // .zip( + // args.defaults + // .iter() + // .rev() + // .map(|x| -> Option<&ast::Expr> { Some(x) }) + // .chain(std::iter::repeat(None)), + // ) + // .collect_vec(); + + // for (x, default) in arg_with_default.into_iter().rev() { + // let name = x.node.arg; + // if name != zelf { + // let type_ann = { + // let annotation_expr = x + // .node + // .annotation + // .as_ref() + // .ok_or_else(|| HashSet::from([ + // format!( + // "type annotation needed for `{}` at {}", + // x.node.arg, x.location + // ), + // ]))? + // .as_ref(); + // parse_ast_to_type_annotation_kinds( + // class_resolver, + // &temp_def_list, + // unifier, + // primitives, + // annotation_expr, + // vec![(class_id, class_type_vars_def.clone())] + // .into_iter() + // .collect::>(), + // )? + // }; + // // find type vars within this method parameter type annotation + // let type_vars_within = + // get_type_var_contained_in_type_annotation(&type_ann); + // // handle the class type var and the method type var + // for type_var_within in type_vars_within { + // let TypeAnnotation::TypeVar(ty) = type_var_within else { + // unreachable!("must be type var annotation") + // }; + + // let id = Self::get_var_id(ty, unifier)?; + // if let Some(prev_ty) = method_var_map.insert(id, ty) { + // // if already in the list, make sure they are the same? + // assert_eq!(prev_ty, ty); + // } + // } + // // finish handling type vars + // let dummy_func_arg = FuncArg { + // name, + // ty: unifier.get_dummy_var().ty, + // default_value: match default { + // None => None, + // Some(default) => { + // if name == "self".into() { + // return Err(HashSet::from([ + // format!("`self` parameter cannot take default value (at {})", x.location), + // ])); + // } + // Some({ + // let v = Self::parse_parameter_default_value( + // default, + // class_resolver, + // )?; + // Self::check_default_param_type( + // &v, &type_ann, primitives, unifier, + // ) + // .map_err(|err| HashSet::from([ + // format!("{} (at {})", err, x.location), + // ]))?; + // v + // }) + // } + // }, + // }; + // // push the dummy type and the type annotation + // // into the list for later unification + // type_var_to_concrete_def + // .insert(dummy_func_arg.ty, type_ann.clone()); + // result.push(dummy_func_arg); + // } + // } + // result + // }; + + // let ret_type = { + // if let Some(result) = returns { + // let result = result.as_ref(); + // let annotation = parse_ast_to_type_annotation_kinds( + // class_resolver, + // &temp_def_list, + // unifier, + // primitives, + // result, + // vec![(class_id, class_type_vars_def.clone())] + // .into_iter() + // .collect::>(), + // )?; + // // find type vars within this return type annotation + // let type_vars_within = + // get_type_var_contained_in_type_annotation(&annotation); + // // handle the class type var and the method type var + // for type_var_within in type_vars_within { + // let TypeAnnotation::TypeVar(ty) = type_var_within else { + // unreachable!("must be type var annotation"); + // }; + + // let id = Self::get_var_id(ty, unifier)?; + // if let Some(prev_ty) = method_var_map.insert(id, ty) { + // // if already in the list, make sure they are the same? + // assert_eq!(prev_ty, ty); + // } + // } + // let dummy_return_type = unifier.get_dummy_var().ty; + // type_var_to_concrete_def.insert(dummy_return_type, annotation.clone()); + // dummy_return_type + // } else { + // // if do not have return annotation, return none + // // for uniform handling, still use type annotation + // let dummy_return_type = unifier.get_dummy_var().ty; + // type_var_to_concrete_def.insert( + // dummy_return_type, + // TypeAnnotation::Primitive(primitives.none), + // ); + // dummy_return_type + // } + // }; + + // let TopLevelDef::Function { var_id, .. } = + // &mut *temp_def_list.get(method_id.0).unwrap().write() else { + // unreachable!() + // }; + // var_id.extend_from_slice(method_var_map + // .iter() + // .filter_map(|(id, ty)| { + // if matches!(&*unifier.get_ty(*ty), TypeEnum::TVar { range, .. } if range.is_empty()) { + // None + // } else { + // Some(*id) + // } + // }) + // .collect_vec() + // .as_slice() + // ); + // let method_type = unifier.add_ty(TypeEnum::TFunc(FunSignature { + // args: arg_types, + // ret: ret_type, + // vars: method_var_map, + // })); + + // // unify now since function type is not in type annotation define + // // which should be fine since type within method_type will be subst later + // unifier + // .unify(method_dummy_ty, method_type) + // .map_err(|e| HashSet::from([e.to_display(unifier).to_string()]))?; + + + } + _ => {} + } + } + + } + if !errors.is_empty() { + return Err(errors); + } + Ok(()) } + + fn analyze_single_class_methods_fields2( + class_def: &Arc>, + class_ast: &ast::StmtKind<()>, + temp_def_list: &[Arc>], + unifier: &mut Unifier, + primitives: &PrimitiveStore, + type_var_to_concrete_def: &mut HashMap, + core_info: (&HashSet, &ComposerConfig), + ) -> Result<(), HashSet> { + let (keyword_list, core_config) = core_info; + let mut class_def = class_def.write(); + let TopLevelDef::Class { + object_id, + ancestors, + fields, + attributes, + methods, + resolver, + type_vars, + .. + } = &mut *class_def + else { + unreachable!("here must be toplevel class def"); + }; + let ast::StmtKind::ClassDef { name, bases, body, .. } = &class_ast else { + unreachable!("here must be class def ast") + }; + + + let ( + class_id, + _class_name, + _class_bases_ast, + class_body_ast, + _class_ancestor_def, + class_fields_def, + class_attributes_def, + class_methods_def, + class_type_vars_def, + class_resolver, + ) = ( + *object_id, *name, bases, body, ancestors, fields, attributes, methods, type_vars, + resolver, + ); + + let class_resolver = class_resolver.as_ref().unwrap(); + let class_resolver = class_resolver.as_ref(); + + let mut defined_fields: HashSet<_> = HashSet::new(); + for b in class_body_ast { + match &b.node { + ast::StmtKind::FunctionDef { args, returns, name, .. } => { + let (method_dummy_ty, method_id) = + Self::get_class_method_def_info(class_methods_def, *name)?; + + let mut method_var_map = VarMap::new(); + + let arg_types: Vec = { + // check method parameters cannot have same name + let mut defined_parameter_name: HashSet<_> = HashSet::new(); + let zelf: StrRef = "self".into(); + for x in &args.args { + if !defined_parameter_name.insert(x.node.arg) + || (keyword_list.contains(&x.node.arg) && x.node.arg != zelf) + { + return Err(HashSet::from([ + format!("top level function must have unique parameter names \ + and names should not be the same as the keywords (at {})", + x.location), + ])) + } + } + + if name == &"__init__".into() && !defined_parameter_name.contains(&zelf) { + return Err(HashSet::from([ + format!("__init__ method must have a `self` parameter (at {})", b.location), + ])) + } + if !defined_parameter_name.contains(&zelf) { + return Err(HashSet::from([ + format!("class method must have a `self` parameter (at {})", b.location), + ])) + } + + let mut result = Vec::new(); + + let arg_with_default: Vec<( + &ast::Located>, + Option<&ast::Expr>, + )> = args + .args + .iter() + .rev() + .zip( + args.defaults + .iter() + .rev() + .map(|x| -> Option<&ast::Expr> { Some(x) }) + .chain(std::iter::repeat(None)), + ) + .collect_vec(); + + for (x, default) in arg_with_default.into_iter().rev() { + let name = x.node.arg; + if name != zelf { + let type_ann = { + let annotation_expr = x + .node + .annotation + .as_ref() + .ok_or_else(|| HashSet::from([ + format!( + "type annotation needed for `{}` at {}", + x.node.arg, x.location + ), + ]))? + .as_ref(); + parse_ast_to_type_annotation_kinds( + class_resolver, + temp_def_list, + unifier, + primitives, + annotation_expr, + vec![(class_id, class_type_vars_def.clone())] + .into_iter() + .collect::>(), + )? + }; + // find type vars within this method parameter type annotation + let type_vars_within = + get_type_var_contained_in_type_annotation(&type_ann); + // handle the class type var and the method type var + for type_var_within in type_vars_within { + let TypeAnnotation::TypeVar(ty) = type_var_within else { + unreachable!("must be type var annotation") + }; + + let id = Self::get_var_id(ty, unifier)?; + if let Some(prev_ty) = method_var_map.insert(id, ty) { + // if already in the list, make sure they are the same? + assert_eq!(prev_ty, ty); + } + } + // finish handling type vars + let dummy_func_arg = FuncArg { + name, + ty: unifier.get_dummy_var().ty, + default_value: match default { + None => None, + Some(default) => { + if name == "self".into() { + return Err(HashSet::from([ + format!("`self` parameter cannot take default value (at {})", x.location), + ])); + } + Some({ + let v = Self::parse_parameter_default_value( + default, + class_resolver, + )?; + Self::check_default_param_type( + &v, &type_ann, primitives, unifier, + ) + .map_err(|err| HashSet::from([ + format!("{} (at {})", err, x.location), + ]))?; + v + }) + } + }, + }; + // push the dummy type and the type annotation + // into the list for later unification + type_var_to_concrete_def + .insert(dummy_func_arg.ty, type_ann.clone()); + result.push(dummy_func_arg); + } + } + result + }; + + let ret_type = { + if let Some(result) = returns { + let result = result.as_ref(); + let annotation = parse_ast_to_type_annotation_kinds( + class_resolver, + temp_def_list, + unifier, + primitives, + result, + vec![(class_id, class_type_vars_def.clone())] + .into_iter() + .collect::>(), + )?; + // find type vars within this return type annotation + let type_vars_within = + get_type_var_contained_in_type_annotation(&annotation); + // handle the class type var and the method type var + for type_var_within in type_vars_within { + let TypeAnnotation::TypeVar(ty) = type_var_within else { + unreachable!("must be type var annotation"); + }; + + let id = Self::get_var_id(ty, unifier)?; + if let Some(prev_ty) = method_var_map.insert(id, ty) { + // if already in the list, make sure they are the same? + assert_eq!(prev_ty, ty); + } + } + let dummy_return_type = unifier.get_dummy_var().ty; + type_var_to_concrete_def.insert(dummy_return_type, annotation.clone()); + dummy_return_type + } else { + // if do not have return annotation, return none + // for uniform handling, still use type annotation + let dummy_return_type = unifier.get_dummy_var().ty; + type_var_to_concrete_def.insert( + dummy_return_type, + TypeAnnotation::Primitive(primitives.none), + ); + dummy_return_type + } + }; + + let TopLevelDef::Function { var_id, .. } = + &mut *temp_def_list.get(method_id.0).unwrap().write() else { + unreachable!() + }; + var_id.extend_from_slice(method_var_map + .iter() + .filter_map(|(id, ty)| { + if matches!(&*unifier.get_ty(*ty), TypeEnum::TVar { range, .. } if range.is_empty()) { + None + } else { + Some(*id) + } + }) + .collect_vec() + .as_slice() + ); + let method_type = unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: arg_types, + ret: ret_type, + vars: method_var_map, + })); + + // unify now since function type is not in type annotation define + // which should be fine since type within method_type will be subst later + unifier + .unify(method_dummy_ty, method_type) + .map_err(|e| HashSet::from([e.to_display(unifier).to_string()]))?; + } + ast::StmtKind::AnnAssign { target, annotation, value, .. } => { + if let ast::ExprKind::Name { id: attr, .. } = &target.node { + if defined_fields.insert(attr.to_string()) { + let dummy_field_type = unifier.get_dummy_var().ty; + + let annotation = match value { + None => { + // handle Kernel[T], KernelInvariant[T] + let (annotation, mutable) = match &annotation.node { + ast::ExprKind::Subscript { value, slice, .. } + if matches!( + &value.node, + ast::ExprKind::Name { id, .. } if id == &core_config.kernel_invariant_ann.into() + ) => + { + (slice, false) + } + ast::ExprKind::Subscript { value, slice, .. } + if matches!( + &value.node, + ast::ExprKind::Name { id, .. } if core_config.kernel_ann.map_or(false, |c| id == &c.into()) + ) => + { + (slice, true) + } + _ if core_config.kernel_ann.is_none() => (annotation, true), + _ => continue, // ignore fields annotated otherwise + }; + class_fields_def.push((*attr, dummy_field_type, mutable)); + annotation + } + // Supporting Class Attributes + Some(boxed_expr) => { + // Class attributes are set as immutable regardless + let (annotation, _) = match &annotation.node { + ast::ExprKind::Subscript { slice, .. } => (slice, false), + _ if core_config.kernel_ann.is_none() => (annotation, false), + _ => continue, + }; + + match &**boxed_expr { + ast::Located {location: _, custom: (), node: ast::ExprKind::Constant { value: v, kind: _ }} => { + // Restricting the types allowed to be defined as class attributes + match v { + ast::Constant::Bool(_) | ast::Constant::Str(_) | ast::Constant::Int(_) | ast::Constant::Float(_) => {} + _ => { + return Err(HashSet::from([ + format!( + "unsupported statement in class definition body (at {})", + b.location + ), + ])) + } + } + class_attributes_def.push((*attr, dummy_field_type, v.clone())); + } + _ => { + return Err(HashSet::from([ + format!( + "unsupported statement in class definition body (at {})", + b.location + ), + ])) + } + } + annotation + } + }; + let parsed_annotation = parse_ast_to_type_annotation_kinds( + class_resolver, + temp_def_list, + unifier, + primitives, + annotation.as_ref(), + vec![(class_id, class_type_vars_def.clone())] + .into_iter() + .collect::>(), + )?; + // find type vars within this return type annotation + let type_vars_within = + get_type_var_contained_in_type_annotation(&parsed_annotation); + // handle the class type var and the method type var + for type_var_within in type_vars_within { + let TypeAnnotation::TypeVar(t) = type_var_within else { + unreachable!("must be type var annotation") + }; + + if !class_type_vars_def.contains(&t) { + return Err(HashSet::from([ + format!( + "class fields can only use type \ + vars over which the class is generic (at {})", + annotation.location + ), + ])) + } + } + type_var_to_concrete_def.insert(dummy_field_type, parsed_annotation); + } else { + return Err(HashSet::from([ + format!( + "same class fields `{}` defined twice (at {})", + attr, target.location + ), + ])) + } + } else { + return Err(HashSet::from([ + format!( + "unsupported statement type in class definition body (at {})", + target.location + ), + ])) + } + } + ast::StmtKind::Assign { .. } // we don't class attributes + | ast::StmtKind::Expr { value: _, .. } // typically a docstring; ignoring all expressions matches CPython behavior + | ast::StmtKind::Pass { .. } => {} + _ => { + return Err(HashSet::from([ + format!( + "unsupported statement in class definition body (at {})", + b.location + ), + ])) + } + } + } + Ok(()) + } + /// step 3, class fields and methods fn analyze_top_level_class_fields_methods(&mut self) -> Result<(), HashSet> { @@ -2048,4 +2586,316 @@ impl TopLevelComposer { } Ok(()) } + + /// step 1, analyze the type vars associated with top level class + fn analyze_top_level_class_type_var(&mut self) -> Result<(), HashSet> { + let def_list = &self.definition_ast_list; + let temp_def_list = self.extract_def_list(); + let unifier = self.unifier.borrow_mut(); + let primitives_store = &self.primitives_ty; + + let mut analyze = |class_def: &Arc>, class_ast: &Option| { + // only deal with class def here + let mut class_def = class_def.write(); + let (class_bases_ast, class_def_type_vars, class_resolver) = { + if let TopLevelDef::Class { type_vars, resolver, .. } = &mut *class_def { + let Some(ast::Located { node: ast::StmtKind::ClassDef { bases, .. }, .. }) = + class_ast + else { + unreachable!() + }; + + (bases, type_vars, resolver) + } else { + return Ok(()); + } + }; + let class_resolver = class_resolver.as_ref().unwrap(); + let class_resolver = &**class_resolver; + + let mut is_generic = false; + for b in class_bases_ast { + match &b.node { + // analyze typevars bounded to the class, + // only support things like `class A(Generic[T, V])`, + // things like `class A(Generic[T, V, ImportedModule.T])` is not supported + // i.e. only simple names are allowed in the subscript + // should update the TopLevelDef::Class.typevars and the TypeEnum::TObj.params + ast::ExprKind::Subscript { value, slice, .. } + if { + matches!( + &value.node, + ast::ExprKind::Name { id, .. } if id == &"Generic".into() + ) + } => + { + if is_generic { + return Err(HashSet::from([format!( + "only single Generic[...] is allowed (at {})", + b.location + )])); + } + is_generic = true; + + let type_var_list: Vec<&ast::Expr<()>>; + // if `class A(Generic[T, V, G])` + if let ast::ExprKind::Tuple { elts, .. } = &slice.node { + type_var_list = elts.iter().collect_vec(); + // `class A(Generic[T])` + } else { + type_var_list = vec![&**slice]; + } + + // parse the type vars + let type_vars = type_var_list + .into_iter() + .map(|e| { + class_resolver.parse_type_annotation( + &temp_def_list, + unifier, + primitives_store, + e, + ) + }) + .collect::, _>>()?; + + // check if all are unique type vars + let all_unique_type_var = { + let mut occurred_type_var_id: HashSet = HashSet::new(); + type_vars.iter().all(|x| { + let ty = unifier.get_ty(*x); + if let TypeEnum::TVar { id, .. } = ty.as_ref() { + occurred_type_var_id.insert(*id) + } else { + false + } + }) + }; + if !all_unique_type_var { + return Err(HashSet::from([format!( + "duplicate type variable occurs (at {})", + slice.location + )])); + } + + // add to TopLevelDef + class_def_type_vars.extend(type_vars); + } + + // if others, do nothing in this function + _ => continue, + } + } + Ok(()) + }; + let mut errors = HashSet::new(); + for (class_def, class_ast) in def_list.iter().skip(self.builtin_num) { + if class_ast.is_none() { + continue; + } + if let Err(e) = analyze(class_def, class_ast) { + errors.extend(e); + } + } + if !errors.is_empty() { + return Err(errors); + } + Ok(()) + } + + /// step 2, base classes. + /// now that the type vars of all classes are done, handle base classes and + /// put Self class into the ancestors list. We only allow single inheritance + fn analyze_top_level_class_bases(&mut self) -> Result<(), HashSet> { + if self.unifier.top_level.is_none() { + let ctx = Arc::new(self.make_top_level_context()); + self.unifier.top_level = Some(ctx); + } + + let temp_def_list = self.extract_def_list(); + let unifier = self.unifier.borrow_mut(); + let primitive_types = self.primitives_ty; + + let mut get_direct_parents = + |class_def: &Arc>, class_ast: &Option| { + let mut class_def = class_def.write(); + let (class_def_id, class_bases, class_ancestors, class_resolver, class_type_vars) = { + if let TopLevelDef::Class { + ancestors, resolver, object_id, type_vars, .. + } = &mut *class_def + { + let Some(ast::Located { + node: ast::StmtKind::ClassDef { bases, .. }, .. + }) = class_ast + else { + unreachable!() + }; + + (object_id, bases, ancestors, resolver, type_vars) + } else { + return Ok(()); + } + }; + let class_resolver = class_resolver.as_ref().unwrap(); + let class_resolver = &**class_resolver; + + let mut has_base = false; + for b in class_bases { + // type vars have already been handled, so skip on `Generic[...]` + if matches!( + &b.node, + ast::ExprKind::Subscript { value, .. } + if matches!( + &value.node, + ast::ExprKind::Name { id, .. } if id == &"Generic".into() + ) + ) { + continue; + } + + if has_base { + return Err(HashSet::from([format!( + "a class definition can only have at most one base class \ + declaration and one generic declaration (at {})", + b.location + )])); + } + has_base = true; + + // the function parse_ast_to make sure that no type var occurred in + // bast_ty if it is a CustomClassKind + let base_ty = parse_ast_to_type_annotation_kinds( + class_resolver, + &temp_def_list, + unifier, + &primitive_types, + b, + vec![(*class_def_id, class_type_vars.clone())] + .into_iter() + .collect::>(), + )?; + + if let TypeAnnotation::CustomClass { .. } = &base_ty { + class_ancestors.push(base_ty); + } else { + return Err(HashSet::from([format!( + "class base declaration can only be custom class (at {})", + b.location, + )])); + } + } + Ok(()) + }; + + // first, only push direct parent into the list + let mut errors = HashSet::new(); + for (class_def, class_ast) in self.definition_ast_list.iter_mut().skip(self.builtin_num) { + if class_ast.is_none() { + continue; + } + if let Err(e) = get_direct_parents(class_def, class_ast) { + errors.extend(e); + } + } + if !errors.is_empty() { + return Err(errors); + } + + // second, get all ancestors + let mut ancestors_store: HashMap> = HashMap::default(); + let mut get_all_ancestors = + |class_def: &Arc>| -> Result<(), HashSet> { + let class_def = class_def.read(); + let (class_ancestors, class_id) = { + if let TopLevelDef::Class { ancestors, object_id, .. } = &*class_def { + (ancestors, *object_id) + } else { + return Ok(()); + } + }; + ancestors_store.insert( + class_id, + // if class has direct parents, get all ancestors of its parents. Else just empty + if class_ancestors.is_empty() { + vec![] + } else { + Self::get_all_ancestors_helper( + &class_ancestors[0], + temp_def_list.as_slice(), + )? + }, + ); + Ok(()) + }; + for (class_def, ast) in self.definition_ast_list.iter().skip(self.builtin_num) { + if ast.is_none() { + continue; + } + if let Err(e) = get_all_ancestors(class_def) { + errors.extend(e); + } + } + if !errors.is_empty() { + return Err(errors); + } + + // insert the ancestors to the def list + for (class_def, class_ast) in self.definition_ast_list.iter_mut().skip(self.builtin_num) { + if class_ast.is_none() { + continue; + } + let mut class_def = class_def.write(); + let (class_ancestors, class_id, class_type_vars) = { + if let TopLevelDef::Class { ancestors, object_id, type_vars, .. } = &mut *class_def + { + (ancestors, *object_id, type_vars) + } else { + continue; + } + }; + + let ans = ancestors_store.get_mut(&class_id).unwrap(); + class_ancestors.append(ans); + + // insert self type annotation to the front of the vector to maintain the order + class_ancestors + .insert(0, make_self_type_annotation(class_type_vars.as_slice(), class_id)); + + // special case classes that inherit from Exception + if class_ancestors + .iter() + .any(|ann| matches!(ann, TypeAnnotation::CustomClass { id, .. } if id.0 == 7)) + { + // if inherited from Exception, the body should be a pass + let ast::StmtKind::ClassDef { body, .. } = &class_ast.as_ref().unwrap().node else { + unreachable!() + }; + + for stmt in body { + if matches!( + stmt.node, + ast::StmtKind::FunctionDef { .. } | ast::StmtKind::AnnAssign { .. } + ) { + return Err(HashSet::from([ + "Classes inherited from exception should have no custom fields/methods" + .into(), + ])); + } + } + } + } + + // deal with ancestor of Exception object + let TopLevelDef::Class { name, ancestors, object_id, .. } = + &mut *self.definition_ast_list[7].0.write() + else { + unreachable!() + }; + + assert_eq!(*name, "Exception".into()); + ancestors.push(make_self_type_annotation(&[], *object_id)); + + Ok(()) + } + }