diff --git a/nac3core/src/typecheck/context.rs b/nac3core/src/typecheck/context.rs deleted file mode 100644 index 4c23c0d7..00000000 --- a/nac3core/src/typecheck/context.rs +++ /dev/null @@ -1,191 +0,0 @@ -use std::collections::HashMap; -use std::collections::HashSet; - -use super::primitives::get_var; -use super::symbol_resolver::*; -use super::typedef::*; -use rustpython_parser::ast::Location; - -/// Structure for storing top-level type definitions. -/// Used for collecting type signature from source code. -/// Can be converted to `InferenceContext` for type inference in functions. -#[derive(Clone)] -pub struct GlobalContext<'a> { - /// List of type definitions. - pub type_defs: Vec>, - /// List of type variable definitions. - pub var_defs: Vec>, -} - -impl<'a> GlobalContext<'a> { - pub fn new(type_defs: Vec>) -> GlobalContext { - GlobalContext { - type_defs, - var_defs: Vec::new(), - } - } - - pub fn add_type(&mut self, def: TypeDef<'a>) -> TypeId { - self.type_defs.push(def); - TypeId(self.type_defs.len() - 1) - } - - pub fn add_variable(&mut self, def: VarDef<'a>) -> VariableId { - self.var_defs.push(def); - VariableId(self.var_defs.len() - 1) - } - - pub fn get_type_def_mut(&mut self, id: TypeId) -> &mut TypeDef<'a> { - self.type_defs.get_mut(id.0).unwrap() - } - - pub fn get_type_def(&self, id: TypeId) -> &TypeDef { - self.type_defs.get(id.0).unwrap() - } - - pub fn get_var_def(&self, id: VariableId) -> &VarDef { - self.var_defs.get(id.0).unwrap() - } - - pub fn get_var_count(&self) -> usize { - self.var_defs.len() - } -} - -pub struct InferenceContext<'a> { - // a: (i, x) means that a.i = x - pub fields_assignment: HashMap>, - pub constraints: Vec<(Type, Type)>, - global: GlobalContext<'a>, - resolver: Box, - local_identifiers: HashMap<&'a str, Type>, - local_variables: Vec>, - fresh_var_id: usize, -} - -impl<'a> InferenceContext<'a> { - pub fn new( - global: GlobalContext<'a>, - resolver: Box, - ) -> InferenceContext<'a> { - let id = global.get_var_count(); - InferenceContext { - global, - fields_assignment: HashMap::new(), - constraints: Vec::new(), - resolver, - local_identifiers: HashMap::new(), - local_variables: Vec::new(), - fresh_var_id: id, - } - } - - fn get_fresh_var(&mut self) -> VariableId { - self.local_variables.push(VarDef { - name: None, - bound: Vec::new(), - }); - let id = self.fresh_var_id; - self.fresh_var_id += 1; - VariableId(id) - } - - fn get_fresh_var_with_bound(&mut self, bound: Vec) -> VariableId { - self.local_variables.push(VarDef { name: None, bound }); - let id = self.fresh_var_id; - self.fresh_var_id += 1; - VariableId(id) - } - - pub fn assign_identifier(&mut self, identifier: &'a str) -> Type { - if let Some(t) = self.local_identifiers.get(identifier) { - t.clone() - } else if let Some(SymbolType::Identifier(t)) = self.resolver.get_symbol_type(identifier) { - t - } else { - get_var(self.get_fresh_var()) - } - } - - pub fn get_identifier_type(&self, identifier: &'a str) -> Result { - if let Some(t) = self.local_identifiers.get(identifier) { - Ok(t.clone()) - } else if let Some(SymbolType::Identifier(t)) = self.resolver.get_symbol_type(identifier) { - Ok(t) - } else { - Err("unbounded identifier".into()) - } - } - - pub fn get_attribute_type( - &mut self, - expr: Type, - identifier: &'a str, - location: Location, - ) -> Result { - match expr.as_ref() { - TypeEnum::TypeVariable(id) => { - if !self.fields_assignment.contains_key(id) { - self.fields_assignment.insert(*id, Vec::new()); - } - let var_id = VariableId(self.fresh_var_id); - let entry = self.fields_assignment.get_mut(&id).unwrap(); - for (attr, t, _) in entry.iter() { - if *attr == identifier { - return Ok(get_var(*t)); - } - } - entry.push((identifier, var_id, location)); - self.local_variables.push(VarDef { - name: None, - bound: Vec::new(), - }); - self.fresh_var_id += 1; - Ok(get_var(var_id)) - } - TypeEnum::ClassType(id, params) => { - let type_def = self.global.get_type_def(*id); - let field = type_def - .base - .fields - .get(identifier) - .map_or_else(|| Err("no such field".to_owned()), Ok)?; - // function and tuple can have 0 type variables but with type parameters - // we require other types have the same number of type variables and type - // parameters in order to build a mapping - assert!(type_def.params.is_empty() || type_def.params.len() == params.len()); - let map = type_def - .params - .clone() - .into_iter() - .zip(params.clone().into_iter()) - .collect(); - let field = field.subst(&map); - Ok(self.get_instance(field)) - } - } - } - - fn get_instance(&mut self, t: Type) -> Type { - let mut vars = HashSet::new(); - t.get_vars(&mut vars); - - let local_min = self.global.get_var_count(); - let bounded = vars.into_iter().filter(|id| id.0 < local_min); - let map = bounded - .map(|v| { - ( - v, - get_var( - self.get_fresh_var_with_bound(self.global.get_var_def(v).bound.clone()), - ), - ) - }) - .collect(); - t.subst(&map) - } - - pub fn get_type_def(&self, id: TypeId) -> &TypeDef { - self.global.get_type_def(id) - } -} diff --git a/nac3core/src/typecheck/mod.rs b/nac3core/src/typecheck/mod.rs index 7b30426e..bb470cd0 100644 --- a/nac3core/src/typecheck/mod.rs +++ b/nac3core/src/typecheck/mod.rs @@ -1,8 +1,6 @@ #![allow(dead_code)] -// mod context; -// pub mod location; -// mod magic_methods; -// mod primitives; -// pub mod symbol_resolver; +pub mod location; +mod magic_methods; +pub mod symbol_resolver; mod test_typedef; pub mod typedef; diff --git a/nac3core/src/typecheck/primitives.rs b/nac3core/src/typecheck/primitives.rs deleted file mode 100644 index c383e955..00000000 --- a/nac3core/src/typecheck/primitives.rs +++ /dev/null @@ -1,168 +0,0 @@ -use super::context::*; -use super::typedef::{TypeEnum::*, *}; -use std::collections::HashMap; -use std::rc::Rc; - -pub const FUNC_TYPE: TypeId = TypeId(0); -pub const TUPLE_TYPE: TypeId = TypeId(1); -pub const LIST_TYPE: TypeId = TypeId(2); -pub const VIRTUAL_TYPE: TypeId = TypeId(3); -pub const NONE_TYPE: TypeId = TypeId(4); - -pub const BOOL_TYPE: TypeId = TypeId(5); -pub const INT32_TYPE: TypeId = TypeId(6); -pub const INT64_TYPE: TypeId = TypeId(7); -pub const FLOAT_TYPE: TypeId = TypeId(8); - -fn primitive(base: BaseDef) -> TypeDef { - TypeDef { - base, - parents: vec![], - params: vec![], - } -} - -pub fn get_fn(from: Type, to: Type) -> Type { - Rc::new(ClassType(FUNC_TYPE, vec![from, to])) -} - -pub fn get_tuple(types: &[Type]) -> Type { - Rc::new(ClassType(TUPLE_TYPE, types.to_vec())) -} - -pub fn get_list(t: Type) -> Type { - Rc::new(ClassType(LIST_TYPE, vec![t])) -} - -pub fn get_virtual(t: Type) -> Type { - Rc::new(ClassType(VIRTUAL_TYPE, vec![t])) -} - -pub fn get_none() -> Type { - Rc::new(ClassType(NONE_TYPE, Vec::new())) -} - -pub fn get_bool() -> Type { - Rc::new(ClassType(BOOL_TYPE, Vec::new())) -} -pub fn get_int32() -> Type { - Rc::new(ClassType(INT32_TYPE, Vec::new())) -} - -pub fn get_int64() -> Type { - Rc::new(ClassType(INT64_TYPE, Vec::new())) -} - -pub fn get_float() -> Type { - Rc::new(ClassType(FLOAT_TYPE, Vec::new())) -} - -pub fn get_var(id: VariableId) -> Type { - Rc::new(TypeVariable(id)) -} - -fn impl_math(def: &mut BaseDef, ty: &Type) { - let fun = get_fn(ty.clone(), ty.clone()); - def.fields.insert("__add__", fun.clone()); - def.fields.insert("__sub__", fun.clone()); - def.fields.insert("__mul__", fun.clone()); - def.fields.insert("__neg__", get_fn(get_none(), ty.clone())); - def.fields - .insert("__truediv__", get_fn(ty.clone(), get_float())); - def.fields.insert("__floordiv__", fun.clone()); - def.fields.insert("__mod__", fun.clone()); - def.fields.insert("__pow__", fun); -} - -fn impl_bits(def: &mut BaseDef, ty: &Type) { - let fun = get_fn(get_int32(), ty.clone()); - - def.fields.insert("__lshift__", fun.clone()); - def.fields.insert("__rshift__", fun); - def.fields.insert("__xor__", get_fn(ty.clone(), ty.clone())); -} - -fn impl_eq(def: &mut BaseDef, ty: &Type) { - let fun = get_fn(ty.clone(), get_bool()); - - def.fields.insert("__eq__", fun.clone()); - def.fields.insert("__ne__", fun); -} - -fn impl_order(def: &mut BaseDef, ty: &Type) { - let fun = get_fn(ty.clone(), get_bool()); - - def.fields.insert("__lt__", fun.clone()); - def.fields.insert("__gt__", fun.clone()); - def.fields.insert("__le__", fun.clone()); - def.fields.insert("__ge__", fun); -} - -pub fn basic_ctx() -> GlobalContext<'static> { - let mut ctx = GlobalContext::new(vec![ - primitive(BaseDef { - name: "function", - fields: HashMap::new(), - }), - primitive(BaseDef { - name: "tuple", - fields: HashMap::new(), - }), - primitive(BaseDef { - name: "list", - fields: HashMap::new(), - }), - primitive(BaseDef { - name: "virtual", - fields: HashMap::new(), - }), - primitive(BaseDef { - name: "None", - fields: HashMap::new(), - }), - primitive(BaseDef { - name: "bool", - fields: HashMap::new(), - }), - primitive(BaseDef { - name: "int32", - fields: HashMap::new(), - }), - primitive(BaseDef { - name: "int64", - fields: HashMap::new(), - }), - primitive(BaseDef { - name: "float", - fields: HashMap::new(), - }), - ]); - - let t = ctx.add_variable(VarDef { - name: Some("T"), - bound: vec![], - }); - ctx.get_type_def_mut(LIST_TYPE).params.push(t); - - let b_def = ctx.get_type_def_mut(BOOL_TYPE); - impl_eq(&mut b_def.base, &get_bool()); - let int32 = get_int32(); - let int32_def = &mut ctx.get_type_def_mut(INT32_TYPE).base; - impl_math(int32_def, &int32); - impl_bits(int32_def, &int32); - impl_order(int32_def, &int32); - impl_eq(int32_def, &int32); - let int64 = get_int64(); - let int64_def = &mut ctx.get_type_def_mut(INT64_TYPE).base; - impl_math(int64_def, &int64); - impl_bits(int64_def, &int64); - impl_order(int64_def, &int64); - impl_eq(int64_def, &int64); - let float = get_float(); - let float_def = &mut ctx.get_type_def_mut(FLOAT_TYPE).base; - impl_math(float_def, &float); - impl_order(float_def, &float); - impl_eq(float_def, &float); - - ctx -} diff --git a/nac3core/src/typecheck/typedef.rs b/nac3core/src/typecheck/typedef.rs index 6ff93879..e5680feb 100644 --- a/nac3core/src/typecheck/typedef.rs +++ b/nac3core/src/typecheck/typedef.rs @@ -436,7 +436,7 @@ impl Unifier { .collect(); for (i, t) in posargs.iter().enumerate() { if signature.args.len() <= i { - return Err(format!("Too many arguments.")); + return Err("Too many arguments.".to_string()); } if !required.is_empty() { required.pop(); @@ -465,17 +465,17 @@ impl Unifier { TypeEnum::TFunc(sign1) => { if let TypeEnum::TFunc(sign2) = &*ty_b { if !sign1.params.is_empty() || !sign2.params.is_empty() { - return Err(format!("Polymorphic function pointer is prohibited.")); + return Err("Polymorphic function pointer is prohibited.".to_string()); } if sign1.args.len() != sign2.args.len() { - return Err(format!("Functions differ in number of parameters.")); + return Err("Functions differ in number of parameters.".to_string()); } for (x, y) in sign1.args.iter().zip(sign2.args.iter()) { if x.name != y.name { - return Err(format!("Functions differ in parameter names.")); + return Err("Functions differ in parameter names.".to_string()); } if x.is_optional != y.is_optional { - return Err(format!("Functions differ in optional parameters.")); + return Err("Functions differ in optional parameters.".to_string()); } self.unify(x.ty, y.ty)?; } @@ -651,7 +651,7 @@ impl Unifier { let params = new_params.unwrap_or_else(|| params.clone()); let ret = new_ret.unwrap_or_else(|| *ret); let args = new_args.unwrap_or_else(|| args.clone()); - Some(self.add_ty(TypeEnum::TFunc(FunSignature { params, ret, args }))) + Some(self.add_ty(TypeEnum::TFunc(FunSignature { args, ret, params }))) } else { None }