use std::convert::TryInto; use crate::symbol_resolver::SymbolValue; use nac3parser::ast::{Constant, Location}; use super::*; impl TopLevelDef { pub fn to_string(&self, unifier: &mut Unifier) -> String { match self { TopLevelDef::Class { name, ancestors, fields, methods, type_vars, .. } => { let fields_str = fields .iter() .map(|(n, ty, _)| (n.to_string(), unifier.stringify(*ty))) .collect_vec(); let methods_str = methods .iter() .map(|(n, ty, id)| (n.to_string(), unifier.stringify(*ty), *id)) .collect_vec(); format!( "Class {{\nname: {:?},\nancestors: {:?},\nfields: {:?},\nmethods: {:?},\ntype_vars: {:?}\n}}", name, ancestors.iter().map(|ancestor| ancestor.stringify(unifier)).collect_vec(), fields_str.iter().map(|(a, _)| a).collect_vec(), methods_str.iter().map(|(a, b, _)| (a, b)).collect_vec(), type_vars.iter().map(|id| unifier.stringify(*id)).collect_vec(), ) } TopLevelDef::Function { name, signature, var_id, .. } => format!( "Function {{\nname: {:?},\nsig: {:?},\nvar_id: {:?}\n}}", name, unifier.stringify(*signature), { // preserve the order for debug output and test let mut r = var_id.clone(); r.sort_unstable(); r } ), } } } impl TopLevelComposer { pub fn make_primitives() -> (PrimitiveStore, Unifier) { let mut unifier = Unifier::new(); let int32 = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(0), fields: HashMap::new(), params: HashMap::new(), }); let int64 = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(1), fields: HashMap::new(), params: HashMap::new(), }); let float = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(2), fields: HashMap::new(), params: HashMap::new(), }); let bool = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(3), fields: HashMap::new(), params: HashMap::new(), }); let none = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(4), fields: HashMap::new(), params: HashMap::new(), }); let range = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(5), fields: HashMap::new(), params: HashMap::new(), }); let str = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(6), fields: HashMap::new(), params: HashMap::new(), }); let exception = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(7), fields: vec![ ("__name__".into(), (int32, true)), ("__file__".into(), (int32, true)), ("__line__".into(), (int32, true)), ("__col__".into(), (int32, true)), ("__func__".into(), (str, true)), ("__message__".into(), (str, true)), ("__param0__".into(), (int64, true)), ("__param1__".into(), (int64, true)), ("__param2__".into(), (int64, true)), ] .into_iter() .collect::>(), params: HashMap::new(), }); let uint32 = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(8), fields: HashMap::new(), params: HashMap::new(), }); let uint64 = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(9), fields: HashMap::new(), params: HashMap::new(), }); let option_type_var = unifier.get_fresh_var(Some("option_type_var".into()), None); let is_some_type_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![], ret: bool, vars: HashMap::from([(option_type_var.1, option_type_var.0)]), })); let unwrap_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![], ret: option_type_var.0, vars: HashMap::from([(option_type_var.1, option_type_var.0)]), })); let option = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(10), fields: vec![ ("is_some".into(), (is_some_type_fun_ty, true)), ("is_none".into(), (is_some_type_fun_ty, true)), ("unwrap".into(), (unwrap_fun_ty, true)), ] .into_iter() .collect::>(), params: HashMap::from([(option_type_var.1, option_type_var.0)]), }); let primitives = PrimitiveStore { int32, int64, float, bool, none, range, str, exception, uint32, uint64, option, }; crate::typecheck::magic_methods::set_primitives_magic_methods(&primitives, &mut unifier); (primitives, unifier) } /// already include the definition_id of itself inside the ancestors vector /// when first regitering, the type_vars, fields, methods, ancestors are invalid pub fn make_top_level_class_def( index: usize, resolver: Option>, name: StrRef, constructor: Option, loc: Option, ) -> TopLevelDef { TopLevelDef::Class { name, object_id: DefinitionId(index), type_vars: Default::default(), fields: Default::default(), methods: Default::default(), ancestors: Default::default(), constructor, resolver, loc, } } /// when first registering, the type is a invalid value pub fn make_top_level_function_def( name: String, simple_name: StrRef, ty: Type, resolver: Option>, loc: Option, ) -> TopLevelDef { TopLevelDef::Function { name, simple_name, signature: ty, var_id: Default::default(), instance_to_symbol: Default::default(), instance_to_stmt: Default::default(), resolver, codegen_callback: None, loc, } } pub fn make_class_method_name(mut class_name: String, method_name: &str) -> String { class_name.push('.'); class_name.push_str(method_name); class_name } pub fn get_class_method_def_info( class_methods_def: &[(StrRef, Type, DefinitionId)], method_name: StrRef, ) -> Result<(Type, DefinitionId), String> { for (name, ty, def_id) in class_methods_def { if name == &method_name { return Ok((*ty, *def_id)); } } Err(format!("no method {} in the current class", method_name)) } /// get all base class def id of a class, excluding itself. \ /// this function should called only after the direct parent is set /// and before all the ancestors are set /// and when we allow single inheritance \ /// the order of the returned list is from the child to the deepest ancestor pub fn get_all_ancestors_helper( child: &TypeAnnotation, temp_def_list: &[Arc>], ) -> Result, String> { let mut result: Vec = Vec::new(); let mut parent = Self::get_parent(child, temp_def_list); while let Some(p) = parent { parent = Self::get_parent(&p, temp_def_list); let p_id = if let TypeAnnotation::CustomClass { id, .. } = &p { *id } else { unreachable!("must be class kind annotation") }; // check cycle let no_cycle = result.iter().all(|x| { if let TypeAnnotation::CustomClass { id, .. } = x { id.0 != p_id.0 } else { unreachable!("must be class kind annotation") } }); if no_cycle { result.push(p); } else { return Err("cyclic inheritance detected".into()); } } Ok(result) } /// should only be called when finding all ancestors, so panic when wrong fn get_parent( child: &TypeAnnotation, temp_def_list: &[Arc>], ) -> Option { let child_id = if let TypeAnnotation::CustomClass { id, .. } = child { *id } else { unreachable!("should be class type annotation") }; let child_def = temp_def_list.get(child_id.0).unwrap(); let child_def = child_def.read(); if let TopLevelDef::Class { ancestors, .. } = &*child_def { if !ancestors.is_empty() { Some(ancestors[0].clone()) } else { None } } else { unreachable!("child must be top level class def") } } /// get the var_id of a given TVar type pub fn get_var_id(var_ty: Type, unifier: &mut Unifier) -> Result { if let TypeEnum::TVar { id, .. } = unifier.get_ty(var_ty).as_ref() { Ok(*id) } else { Err("not type var".to_string()) } } pub fn check_overload_function_type( this: Type, other: Type, unifier: &mut Unifier, type_var_to_concrete_def: &HashMap, ) -> bool { let this = unifier.get_ty(this); let this = this.as_ref(); let other = unifier.get_ty(other); let other = other.as_ref(); if let ( TypeEnum::TFunc(FunSignature { args: this_args, ret: this_ret, .. }), TypeEnum::TFunc(FunSignature { args: other_args, ret: other_ret, .. }), ) = (this, other) { // check args let args_ok = this_args .iter() .map(|FuncArg { name, ty, .. }| (name, type_var_to_concrete_def.get(ty).unwrap())) .zip(other_args.iter().map(|FuncArg { name, ty, .. }| { (name, type_var_to_concrete_def.get(ty).unwrap()) })) .all(|(this, other)| { if this.0 == &"self".into() && this.0 == other.0 { true } else { this.0 == other.0 && check_overload_type_annotation_compatible(this.1, other.1, unifier) } }); // check rets let ret_ok = check_overload_type_annotation_compatible( type_var_to_concrete_def.get(this_ret).unwrap(), type_var_to_concrete_def.get(other_ret).unwrap(), unifier, ); // return args_ok && ret_ok } else { unreachable!("this function must be called with function type") } } pub fn check_overload_field_type( this: Type, other: Type, unifier: &mut Unifier, type_var_to_concrete_def: &HashMap, ) -> bool { check_overload_type_annotation_compatible( type_var_to_concrete_def.get(&this).unwrap(), type_var_to_concrete_def.get(&other).unwrap(), unifier, ) } pub fn get_all_assigned_field(stmts: &[ast::Stmt<()>]) -> Result, String> { let mut result = HashSet::new(); for s in stmts { match &s.node { ast::StmtKind::AnnAssign { target, .. } if { if let ast::ExprKind::Attribute { value, .. } = &target.node { if let ast::ExprKind::Name { id, .. } = &value.node { id == &"self".into() } else { false } } else { false } } => { return Err(format!( "redundant type annotation for class fields at {}", s.location )) } ast::StmtKind::Assign { targets, .. } => { for t in targets { if let ast::ExprKind::Attribute { value, attr, .. } = &t.node { if let ast::ExprKind::Name { id, .. } = &value.node { if id == &"self".into() { result.insert(*attr); } } } } } // TODO: do not check for For and While? ast::StmtKind::For { body, orelse, .. } | ast::StmtKind::While { body, orelse, .. } => { result.extend(Self::get_all_assigned_field(body.as_slice())?); result.extend(Self::get_all_assigned_field(orelse.as_slice())?); } ast::StmtKind::If { body, orelse, .. } => { let inited_for_sure = Self::get_all_assigned_field(body.as_slice())? .intersection(&Self::get_all_assigned_field(orelse.as_slice())?) .cloned() .collect::>(); result.extend(inited_for_sure); } ast::StmtKind::Try { body, orelse, finalbody, .. } => { let inited_for_sure = Self::get_all_assigned_field(body.as_slice())? .intersection(&Self::get_all_assigned_field(orelse.as_slice())?) .cloned() .collect::>(); result.extend(inited_for_sure); result.extend(Self::get_all_assigned_field(finalbody.as_slice())?); } ast::StmtKind::With { body, .. } => { result.extend(Self::get_all_assigned_field(body.as_slice())?); } ast::StmtKind::Pass { .. } => {} ast::StmtKind::Assert { .. } => {} ast::StmtKind::Expr { .. } => {} _ => { unimplemented!() } } } Ok(result) } pub fn parse_parameter_default_value( default: &ast::Expr, resolver: &(dyn SymbolResolver + Send + Sync), ) -> Result { parse_parameter_default_value(default, resolver) } pub fn check_default_param_type( val: &SymbolValue, ty: &TypeAnnotation, primitive: &PrimitiveStore, unifier: &mut Unifier, ) -> Result<(), String> { let res = match val { SymbolValue::Bool(..) => { if matches!(ty, TypeAnnotation::Primitive(t) if *t == primitive.bool) { None } else { Some("bool".to_string()) } } SymbolValue::Double(..) => { if matches!(ty, TypeAnnotation::Primitive(t) if *t == primitive.float) { None } else { Some("float".to_string()) } } SymbolValue::I32(..) => { if matches!(ty, TypeAnnotation::Primitive(t) if *t == primitive.int32) { None } else { Some("int32".to_string()) } } SymbolValue::I64(..) => { if matches!(ty, TypeAnnotation::Primitive(t) if *t == primitive.int64) { None } else { Some("int64".to_string()) } } SymbolValue::U32(..) => { if matches!(ty, TypeAnnotation::Primitive(t) if *t == primitive.uint32) { None } else { Some("uint32".to_string()) } } SymbolValue::U64(..) => { if matches!(ty, TypeAnnotation::Primitive(t) if *t == primitive.uint64) { None } else { Some("uint64".to_string()) } } SymbolValue::Str(..) => { if matches!(ty, TypeAnnotation::Primitive(t) if *t == primitive.str) { None } else { Some("str".to_string()) } } SymbolValue::Tuple(elts) => { if let TypeAnnotation::Tuple(elts_ty) = ty { for (e, t) in elts.iter().zip(elts_ty.iter()) { Self::check_default_param_type(e, t, primitive, unifier)? } if elts.len() != elts_ty.len() { Some(format!("tuple of length {}", elts.len())) } else { None } } else { Some("tuple".to_string()) } } }; if let Some(found) = res { Err(format!( "incompatible default parameter type, expect {}, found {}", ty.stringify(unifier), found )) } else { Ok(()) } } } pub fn parse_parameter_default_value( default: &ast::Expr, resolver: &(dyn SymbolResolver + Send + Sync), ) -> Result { fn handle_constant(val: &Constant, loc: &Location) -> Result { match val { Constant::Int(v) => { if let Ok(v) = (*v).try_into() { Ok(SymbolValue::I32(v)) } else { Err(format!("integer value out of range at {}", loc)) } } Constant::Float(v) => Ok(SymbolValue::Double(*v)), Constant::Bool(v) => Ok(SymbolValue::Bool(*v)), Constant::Tuple(tuple) => Ok(SymbolValue::Tuple( tuple.iter().map(|x| handle_constant(x, loc)).collect::, _>>()?, )), _ => unimplemented!("this constant is not supported at {}", loc), } } match &default.node { ast::ExprKind::Constant { value, .. } => handle_constant(value, &default.location), ast::ExprKind::Call { func, args, .. } if args.len() == 1 => { match &func.node { ast::ExprKind::Name { id, .. } if *id == "int64".into() => match &args[0].node { ast::ExprKind::Constant { value: Constant::Int(v), .. } => { let v: Result = (*v).try_into(); match v { Ok(v) => Ok(SymbolValue::I64(v)), _ => Err(format!("default param value out of range at {}", default.location)), } } _ => Err(format!("only allow constant integer here at {}", default.location)) } ast::ExprKind::Name { id, .. } if *id == "uint32".into() => match &args[0].node { ast::ExprKind::Constant { value: Constant::Int(v), .. } => { let v: Result = (*v).try_into(); match v { Ok(v) => Ok(SymbolValue::U32(v)), _ => Err(format!("default param value out of range at {}", default.location)), } } _ => Err(format!("only allow constant integer here at {}", default.location)) } ast::ExprKind::Name { id, .. } if *id == "uint64".into() => match &args[0].node { ast::ExprKind::Constant { value: Constant::Int(v), .. } => { let v: Result = (*v).try_into(); match v { Ok(v) => Ok(SymbolValue::U64(v)), _ => Err(format!("default param value out of range at {}", default.location)), } } _ => Err(format!("only allow constant integer here at {}", default.location)) } _ => Err(format!("unsupported default parameter at {}", default.location)), } } ast::ExprKind::Tuple { elts, .. } => Ok(SymbolValue::Tuple(elts .iter() .map(|x| parse_parameter_default_value(x, resolver)) .collect::, _>>()? )), ast::ExprKind::Name { id, .. } => { resolver.get_default_param_value(default).ok_or_else( || format!( "`{}` cannot be used as a default parameter at {} (not primitive type or tuple / not defined?)", id, default.location ) ) } _ => Err(format!("unsupported default parameter at {}", default.location)) } }