diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 62a54473..ab5e4d80 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -485,7 +485,7 @@ impl TopLevelComposer { // things like `class A(Generic[T, V, ImportedModule.T])` is not supported // i.e. only simple names are allowed in the subscript // should update the TopLevelDef::Class.typevars and the TypeEnum::TObj.params - ast::ExprKind::Subscript { value, slice, .. } + ExprKind::Subscript { value, slice, .. } if { matches!( &value.node, @@ -501,9 +501,9 @@ impl TopLevelComposer { } is_generic = true; - let type_var_list: Vec<&ast::Expr<()>>; + let type_var_list: Vec<&Expr<()>>; // if `class A(Generic[T, V, G])` - if let ast::ExprKind::Tuple { elts, .. } = &slice.node { + if let ExprKind::Tuple { elts, .. } = &slice.node { type_var_list = elts.iter().collect_vec(); // `class A(Generic[T])` } else { @@ -1014,18 +1014,18 @@ impl TopLevelComposer { } } - let arg_with_default: Vec<(&ast::Located>, Option<&ast::Expr>)> = - args.args - .iter() - .rev() - .zip( - args.defaults - .iter() - .rev() - .map(|x| -> Option<&ast::Expr> { Some(x) }) - .chain(std::iter::repeat(None)), - ) - .collect_vec(); + let arg_with_default: Vec<(&ast::Located>, Option<&Expr>)> = args + .args + .iter() + .rev() + .zip( + args.defaults + .iter() + .rev() + .map(|x| -> Option<&Expr> { Some(x) }) + .chain(std::iter::repeat(None)), + ) + .collect_vec(); arg_with_default .iter() @@ -1283,7 +1283,7 @@ impl TopLevelComposer { let arg_with_default: Vec<( &ast::Located>, - Option<&ast::Expr>, + Option<&Expr>, )> = args .args .iter() @@ -1292,7 +1292,7 @@ impl TopLevelComposer { args.defaults .iter() .rev() - .map(|x| -> Option<&ast::Expr> { Some(x) }) + .map(|x| -> Option<&Expr> { Some(x) }) .chain(std::iter::repeat(None)), ) .collect_vec(); @@ -1449,7 +1449,7 @@ impl TopLevelComposer { .map_err(|e| HashSet::from([e.to_display(unifier).to_string()]))?; } ast::StmtKind::AnnAssign { target, annotation, value, .. } => { - if let ast::ExprKind::Name { id: attr, .. } = &target.node { + if let ExprKind::Name { id: attr, .. } = &target.node { if defined_fields.insert(attr.to_string()) { let dummy_field_type = unifier.get_dummy_var().ty; @@ -1457,7 +1457,7 @@ impl TopLevelComposer { None => { // handle Kernel[T], KernelInvariant[T] let (annotation, mutable) = match &annotation.node { - ast::ExprKind::Subscript { value, slice, .. } + ExprKind::Subscript { value, slice, .. } if matches!( &value.node, ast::ExprKind::Name { id, .. } if id == &core_config.kernel_invariant_ann.into() @@ -1465,7 +1465,7 @@ impl TopLevelComposer { { (slice, false) } - ast::ExprKind::Subscript { value, slice, .. } + ExprKind::Subscript { value, slice, .. } if matches!( &value.node, ast::ExprKind::Name { id, .. } if core_config.kernel_ann.map_or(false, |c| id == &c.into()) @@ -1483,13 +1483,13 @@ impl TopLevelComposer { Some(boxed_expr) => { // Class attributes are set as immutable regardless let (annotation, _) = match &annotation.node { - ast::ExprKind::Subscript { slice, .. } => (slice, false), + ExprKind::Subscript { slice, .. } => (slice, false), _ if core_config.kernel_ann.is_none() => (annotation, false), _ => continue, }; match &**boxed_expr { - ast::Located {location: _, custom: (), node: ast::ExprKind::Constant { value: v, kind: _ }} => { + ast::Located {location: _, custom: (), node: ExprKind::Constant { value: v, kind: _ }} => { // Restricting the types allowed to be defined as class attributes match v { ast::Constant::Bool(_) | ast::Constant::Str(_) | ast::Constant::Int(_) | ast::Constant::Float(_) => {} @@ -1937,284 +1937,296 @@ impl TopLevelComposer { if ast.is_none() { return Ok(()); } - let mut function_def = def.write(); - if let TopLevelDef::Function { - instance_to_stmt, - instance_to_symbol, - name, - simple_name, - signature, - resolver, - .. - } = &mut *function_def - { - let signature_ty_enum = unifier.get_ty(*signature); - let TypeEnum::TFunc(FunSignature { args, ret, vars, .. }) = - signature_ty_enum.as_ref() + + let (name, simple_name, signature, resolver) = { + let function_def = def.read(); + let TopLevelDef::Function { name, simple_name, signature, resolver, .. } = + &*function_def else { - unreachable!("must be typeenum::tfunc") + return Ok(()); }; - let mut vars = vars.clone(); - // None if is not class method - let uninst_self_type = { - if let Some(class_id) = method_class.get(&DefinitionId(id)) { - let class_def = definition_ast_list.get(class_id.0).unwrap(); - let class_def = class_def.0.read(); - let TopLevelDef::Class { type_vars, .. } = &*class_def else { - unreachable!("must be class def") + (name.clone(), *simple_name, *signature, resolver.clone()) + }; + + let signature_ty_enum = unifier.get_ty(signature); + let TypeEnum::TFunc(FunSignature { args, ret, vars, .. }) = signature_ty_enum.as_ref() + else { + unreachable!("must be typeenum::tfunc") + }; + + let mut vars = vars.clone(); + // None if is not class method + let uninst_self_type = { + if let Some(class_id) = method_class.get(&DefinitionId(id)) { + let class_def = definition_ast_list.get(class_id.0).unwrap(); + let class_def = class_def.0.read(); + let TopLevelDef::Class { type_vars, .. } = &*class_def else { + unreachable!("must be class def") + }; + + let ty_ann = make_self_type_annotation(type_vars, *class_id); + let self_ty = get_type_from_type_annotation_kinds( + &def_list, + unifier, + primitives_ty, + &ty_ann, + &mut None, + )?; + vars.extend(type_vars.iter().map(|ty| { + let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*ty) else { + unreachable!() }; - let ty_ann = make_self_type_annotation(type_vars, *class_id); - let self_ty = get_type_from_type_annotation_kinds( - &def_list, - unifier, - primitives_ty, - &ty_ann, - &mut None, - )?; - vars.extend(type_vars.iter().map(|ty| { - let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*ty) else { + (*id, *ty) + })); + Some((self_ty, type_vars.clone())) + } else { + None + } + }; + // carefully handle those with bounds, without bounds and no typevars + // if class methods, `vars` also contains all class typevars here + let (type_var_subst_comb, no_range_vars) = { + let mut no_ranges: Vec = Vec::new(); + let var_combs = vars + .values() + .map(|ty| { + unifier.get_instantiations(*ty).unwrap_or_else(|| { + let TypeEnum::TVar { name, loc, is_const_generic: false, .. } = + &*unifier.get_ty(*ty) + else { unreachable!() }; - (*id, *ty) - })); - Some((self_ty, type_vars.clone())) - } else { - None - } - }; - // carefully handle those with bounds, without bounds and no typevars - // if class methods, `vars` also contains all class typevars here - let (type_var_subst_comb, no_range_vars) = { - let mut no_ranges: Vec = Vec::new(); - let var_combs = vars - .values() - .map(|ty| { - unifier.get_instantiations(*ty).unwrap_or_else(|| { - let TypeEnum::TVar { name, loc, is_const_generic: false, .. } = - &*unifier.get_ty(*ty) - else { - unreachable!() - }; - - let rigid = unifier.get_fresh_rigid_var(*name, *loc).ty; - no_ranges.push(rigid); - vec![rigid] - }) + let rigid = unifier.get_fresh_rigid_var(*name, *loc).ty; + no_ranges.push(rigid); + vec![rigid] }) - .multi_cartesian_product() - .collect_vec(); - let mut result: Vec = Vec::default(); - for comb in var_combs { - result.push(vars.keys().copied().zip(comb).collect()); + }) + .multi_cartesian_product() + .collect_vec(); + let mut result: Vec = Vec::default(); + for comb in var_combs { + result.push(vars.keys().copied().zip(comb).collect()); + } + // NOTE: if is empty, means no type var, append a empty subst, ok to do this? + if result.is_empty() { + result.push(VarMap::new()); + } + (result, no_ranges) + }; + + for subst in type_var_subst_comb { + // for each instance + let inst_ret = unifier.subst(*ret, &subst).unwrap_or(*ret); + let inst_args = { + args.iter() + .map(|a| FuncArg { + name: a.name, + ty: unifier.subst(a.ty, &subst).unwrap_or(a.ty), + default_value: a.default_value.clone(), + is_vararg: false, + }) + .collect_vec() + }; + let self_type = { + uninst_self_type.clone().map(|(self_type, type_vars)| { + let subst_for_self = { + let class_ty_var_ids = type_vars + .iter() + .map(|x| { + if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*x) { + *id + } else { + unreachable!("must be type var here"); + } + }) + .collect::>(); + subst + .iter() + .filter_map(|(ty_var_id, ty_var_target)| { + if class_ty_var_ids.contains(ty_var_id) { + Some((*ty_var_id, *ty_var_target)) + } else { + None + } + }) + .collect::() + }; + unifier.subst(self_type, &subst_for_self).unwrap_or(self_type) + }) + }; + let mut identifiers = { + let mut result = HashMap::new(); + if self_type.is_some() { + result.insert("self".into(), IdentifierInfo::default()); } - // NOTE: if is empty, means no type var, append a empty subst, ok to do this? - if result.is_empty() { - result.push(VarMap::new()); - } - (result, no_ranges) + result.extend(inst_args.iter().map(|x| (x.name, IdentifierInfo::default()))); + result + }; + let mut calls: HashMap = HashMap::new(); + let mut inferencer = Inferencer { + top_level: ctx.as_ref(), + defined_identifiers: identifiers.clone(), + function_data: &mut FunctionData { + resolver: resolver.as_ref().unwrap().clone(), + return_type: if unifier.unioned(inst_ret, primitives_ty.none) { + None + } else { + Some(inst_ret) + }, + // NOTE: allowed type vars + bound_variables: no_range_vars.clone(), + }, + unifier, + variable_mapping: { + let mut result: HashMap = HashMap::new(); + if let Some(self_ty) = self_type { + result.insert("self".into(), self_ty); + } + result.extend(inst_args.iter().map(|x| (x.name, x.ty))); + result + }, + primitives: primitives_ty, + virtual_checks: &mut Vec::new(), + calls: &mut calls, + in_handler: false, }; - for subst in type_var_subst_comb { - // for each instance - let inst_ret = unifier.subst(*ret, &subst).unwrap_or(*ret); - let inst_args = { - args.iter() - .map(|a| FuncArg { - name: a.name, - ty: unifier.subst(a.ty, &subst).unwrap_or(a.ty), - default_value: a.default_value.clone(), - is_vararg: false, - }) - .collect_vec() - }; - let self_type = { - uninst_self_type.clone().map(|(self_type, type_vars)| { - let subst_for_self = { - let class_ty_var_ids = type_vars - .iter() - .map(|x| { - if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*x) { - *id - } else { - unreachable!("must be type var here"); - } - }) - .collect::>(); - subst - .iter() - .filter_map(|(ty_var_id, ty_var_target)| { - if class_ty_var_ids.contains(ty_var_id) { - Some((*ty_var_id, *ty_var_target)) - } else { - None - } - }) - .collect::() + let ast::StmtKind::FunctionDef { body, decorator_list, .. } = + ast.clone().unwrap().node + else { + unreachable!("must be function def ast") + }; + + if !decorator_list.is_empty() { + if matches!(&decorator_list[0].node, ExprKind::Name { id, .. } if id == &"extern".into()) + { + let TopLevelDef::Function { instance_to_symbol, .. } = &mut *def.write() + else { + unreachable!() + }; + instance_to_symbol.insert(String::new(), simple_name.to_string()); + continue; + } + + if matches!(&decorator_list[0].node, ExprKind::Name { id, .. } if id == &"rpc".into()) + { + let TopLevelDef::Function { instance_to_symbol, .. } = &mut *def.write() + else { + unreachable!() + }; + instance_to_symbol.insert(String::new(), simple_name.to_string()); + continue; + } + + if let ExprKind::Call { func, .. } = &decorator_list[0].node { + if matches!(&func.node, ExprKind::Name { id, .. } if id == &"rpc".into()) { + let TopLevelDef::Function { instance_to_symbol, .. } = + &mut *def.write() + else { + unreachable!() }; - unifier.subst(self_type, &subst_for_self).unwrap_or(self_type) - }) - }; - let mut identifiers = { - let mut result = HashMap::new(); - if self_type.is_some() { - result.insert("self".into(), IdentifierInfo::default()); - } - result - .extend(inst_args.iter().map(|x| (x.name, IdentifierInfo::default()))); - result - }; - let mut calls: HashMap = HashMap::new(); - let mut inferencer = Inferencer { - top_level: ctx.as_ref(), - defined_identifiers: identifiers.clone(), - function_data: &mut FunctionData { - resolver: resolver.as_ref().unwrap().clone(), - return_type: if unifier.unioned(inst_ret, primitives_ty.none) { - None - } else { - Some(inst_ret) - }, - // NOTE: allowed type vars - bound_variables: no_range_vars.clone(), - }, - unifier, - variable_mapping: { - let mut result: HashMap = HashMap::new(); - if let Some(self_ty) = self_type { - result.insert("self".into(), self_ty); - } - result.extend(inst_args.iter().map(|x| (x.name, x.ty))); - result - }, - primitives: primitives_ty, - virtual_checks: &mut Vec::new(), - calls: &mut calls, - in_handler: false, - }; - - let ast::StmtKind::FunctionDef { body, decorator_list, .. } = - ast.clone().unwrap().node - else { - unreachable!("must be function def ast") - }; - if !decorator_list.is_empty() - && matches!(&decorator_list[0].node, - ast::ExprKind::Name{ id, .. } if id == &"extern".into()) - { - instance_to_symbol.insert(String::new(), simple_name.to_string()); - continue; - } - if !decorator_list.is_empty() - && matches!(&decorator_list[0].node, - ast::ExprKind::Name{ id, .. } if id == &"rpc".into()) - { - instance_to_symbol.insert(String::new(), simple_name.to_string()); - continue; - } - if !decorator_list.is_empty() { - if let ast::ExprKind::Call { func, .. } = &decorator_list[0].node { - if matches!(&func.node, - ast::ExprKind::Name{ id, .. } if id == &"rpc".into()) - { - instance_to_symbol.insert(String::new(), simple_name.to_string()); - continue; - } + instance_to_symbol.insert(String::new(), simple_name.to_string()); + continue; } } + } - let fun_body = body - .into_iter() + let fun_body = + body.into_iter() .map(|b| inferencer.fold_stmt(b)) .collect::, _>>()?; - let returned = inferencer.check_block(fun_body.as_slice(), &mut identifiers)?; - { - // check virtuals - let defs = ctx.definitions.read(); - for (subtype, base, loc) in &*inferencer.virtual_checks { - let base_id = { - let base = inferencer.unifier.get_ty(*base); - if let TypeEnum::TObj { obj_id, .. } = &*base { - *obj_id - } else { - return Err(HashSet::from([format!( - "Base type should be a class (at {loc})" - )])); - } - }; - let subtype_id = { - let ty = inferencer.unifier.get_ty(*subtype); - if let TypeEnum::TObj { obj_id, .. } = &*ty { - *obj_id - } else { - let base_repr = inferencer.unifier.stringify(*base); - let subtype_repr = inferencer.unifier.stringify(*subtype); - return Err(HashSet::from([format!( - "Expected a subtype of {base_repr}, but got {subtype_repr} (at {loc})"), - ])); - } - }; - let subtype_entry = defs[subtype_id.0].read(); - let TopLevelDef::Class { ancestors, .. } = &*subtype_entry else { - unreachable!() - }; - - let m = ancestors.iter() - .find(|kind| matches!(kind, TypeAnnotation::CustomClass { id, .. } if *id == base_id)); - if m.is_none() { + let returned = inferencer.check_block(fun_body.as_slice(), &mut identifiers)?; + { + // check virtuals + let defs = ctx.definitions.read(); + for (subtype, base, loc) in &*inferencer.virtual_checks { + let base_id = { + let base = inferencer.unifier.get_ty(*base); + if let TypeEnum::TObj { obj_id, .. } = &*base { + *obj_id + } else { + return Err(HashSet::from([format!( + "Base type should be a class (at {loc})" + )])); + } + }; + let subtype_id = { + let ty = inferencer.unifier.get_ty(*subtype); + if let TypeEnum::TObj { obj_id, .. } = &*ty { + *obj_id + } else { let base_repr = inferencer.unifier.stringify(*base); let subtype_repr = inferencer.unifier.stringify(*subtype); return Err(HashSet::from([format!( "Expected a subtype of {base_repr}, but got {subtype_repr} (at {loc})"), ])); } + }; + let subtype_entry = defs[subtype_id.0].read(); + let TopLevelDef::Class { ancestors, .. } = &*subtype_entry else { + unreachable!() + }; + + let m = ancestors.iter() + .find(|kind| matches!(kind, TypeAnnotation::CustomClass { id, .. } if *id == base_id)); + if m.is_none() { + let base_repr = inferencer.unifier.stringify(*base); + let subtype_repr = inferencer.unifier.stringify(*subtype); + return Err(HashSet::from([format!( + "Expected a subtype of {base_repr}, but got {subtype_repr} (at {loc})"), + ])); } } - if !unifier.unioned(inst_ret, primitives_ty.none) && !returned { - let def_ast_list = &definition_ast_list; - let ret_str = unifier.internal_stringify( - inst_ret, - &mut |id| { - let TopLevelDef::Class { name, .. } = &*def_ast_list[id].0.read() - else { - unreachable!("must be class id here") - }; - - name.to_string() - }, - &mut |id| format!("typevar{id}"), - &mut None, - ); - return Err(HashSet::from([format!( - "expected return type of `{}` in function `{}` (at {})", - ret_str, - name, - ast.as_ref().unwrap().location - )])); - } - - instance_to_stmt.insert( - get_subst_key( - unifier, - self_type, - &subst, - Some(&vars.keys().copied().collect()), - ), - FunInstance { - body: Arc::new(fun_body), - unifier_id: 0, - calls: Arc::new(calls), - subst, - }, - ); } + if !unifier.unioned(inst_ret, primitives_ty.none) && !returned { + let def_ast_list = &definition_ast_list; + let ret_str = unifier.internal_stringify( + inst_ret, + &mut |id| { + let TopLevelDef::Class { name, .. } = &*def_ast_list[id].0.read() + else { + unreachable!("must be class id here") + }; + + name.to_string() + }, + &mut |id| format!("typevar{id}"), + &mut None, + ); + return Err(HashSet::from([format!( + "expected return type of `{}` in function `{}` (at {})", + ret_str, + name, + ast.as_ref().unwrap().location + )])); + } + + let TopLevelDef::Function { instance_to_stmt, .. } = &mut *def.write() else { + unreachable!() + }; + instance_to_stmt.insert( + get_subst_key( + unifier, + self_type, + &subst, + Some(&vars.keys().copied().collect()), + ), + FunInstance { + body: Arc::new(fun_body), + unifier_id: 0, + calls: Arc::new(calls), + subst, + }, + ); } Ok(()) }; + for (id, (def, ast)) in self.definition_ast_list.iter().enumerate().skip(self.builtin_num) { if ast.is_none() { continue;