From 2985b883519913ab40ee52d86cf8fe50ec52290e Mon Sep 17 00:00:00 2001 From: pca006132 Date: Wed, 30 Jun 2021 16:28:18 +0800 Subject: [PATCH] refactor for HM style inference... --- nac3core/src/typecheck/context.rs | 163 ++++++ .../src/typecheck/context/global_context.rs | 109 ---- .../typecheck/context/inference_context.rs | 202 ------- nac3core/src/typecheck/context/mod.rs | 4 - nac3core/src/typecheck/inference_core.rs | 525 ------------------ nac3core/src/typecheck/mod.rs | 7 +- nac3core/src/typecheck/primitives.rs | 276 +++++---- nac3core/src/typecheck/typedef.rs | 74 +-- 8 files changed, 335 insertions(+), 1025 deletions(-) create mode 100644 nac3core/src/typecheck/context.rs delete mode 100644 nac3core/src/typecheck/context/global_context.rs delete mode 100644 nac3core/src/typecheck/context/inference_context.rs delete mode 100644 nac3core/src/typecheck/context/mod.rs delete mode 100644 nac3core/src/typecheck/inference_core.rs diff --git a/nac3core/src/typecheck/context.rs b/nac3core/src/typecheck/context.rs new file mode 100644 index 00000000..c1b8bb07 --- /dev/null +++ b/nac3core/src/typecheck/context.rs @@ -0,0 +1,163 @@ +use std::collections::HashMap; + +use super::primitives::get_var; +use super::symbol_resolver::*; +use super::typedef::*; +use rustpython_parser::ast::Location; + +/// Structure for storing top-level type definitions. +/// Used for collecting type signature from source code. +/// Can be converted to `InferenceContext` for type inference in functions. +#[derive(Clone)] +pub struct GlobalContext<'a> { + /// List of type definitions. + pub type_defs: Vec>, + /// List of type variable definitions. + pub var_defs: Vec>, +} + +impl<'a> GlobalContext<'a> { + pub fn new(type_defs: Vec>) -> GlobalContext { + GlobalContext { + type_defs, + var_defs: Vec::new(), + } + } + + pub fn add_type(&mut self, def: TypeDef<'a>) -> TypeId { + self.type_defs.push(def); + TypeId(self.type_defs.len() - 1) + } + + pub fn add_variable(&mut self, def: VarDef<'a>) -> VariableId { + self.var_defs.push(def); + VariableId(self.var_defs.len() - 1) + } + + pub fn get_type_def_mut(&mut self, id: TypeId) -> &mut TypeDef<'a> { + self.type_defs.get_mut(id.0).unwrap() + } + + pub fn get_type_def(&self, id: TypeId) -> &TypeDef { + self.type_defs.get(id.0).unwrap() + } + + pub fn get_var_def(&self, id: VariableId) -> &VarDef { + self.var_defs.get(id.0).unwrap() + } + + pub fn get_var_count(&self) -> usize { + self.var_defs.len() + } +} + +pub struct InferenceContext<'a> { + // a: (i, x) means that a.i = x + pub fields_assignment: HashMap>, + pub constraints: Vec<(Type, Type)>, + global: GlobalContext<'a>, + resolver: Box, + local_identifiers: HashMap<&'a str, Type>, + local_variables: Vec>, + fresh_var_id: usize, +} + +impl<'a> InferenceContext<'a> { + pub fn new( + global: GlobalContext<'a>, + resolver: Box, + ) -> InferenceContext<'a> { + let id = global.get_var_count(); + InferenceContext { + global, + fields_assignment: HashMap::new(), + constraints: Vec::new(), + resolver, + local_identifiers: HashMap::new(), + local_variables: Vec::new(), + fresh_var_id: id, + } + } + + fn get_fresh_var(&mut self) -> VariableId { + self.local_variables.push(VarDef { + name: None, + bound: Vec::new(), + }); + let id = self.fresh_var_id; + self.fresh_var_id += 1; + VariableId(id) + } + + pub fn assign_identifier(&mut self, identifier: &'a str) -> Type { + if let Some(t) = self.local_identifiers.get(identifier) { + t.clone() + } else if let Some(SymbolType::Identifier(t)) = self.resolver.get_symbol_type(identifier) { + t + } else { + get_var(self.get_fresh_var()) + } + } + + pub fn get_identifier_type(&self, identifier: &'a str) -> Result { + if let Some(t) = self.local_identifiers.get(identifier) { + Ok(t.clone()) + } else if let Some(SymbolType::Identifier(t)) = self.resolver.get_symbol_type(identifier) { + Ok(t) + } else { + Err("unbounded identifier".into()) + } + } + + pub fn get_attribute_type( + &mut self, + expr: Type, + identifier: &'a str, + location: Location, + ) -> Result { + match expr.as_ref() { + TypeEnum::TypeVariable(id) => { + if !self.fields_assignment.contains_key(id) { + self.fields_assignment.insert(*id, Vec::new()); + } + let var_id = VariableId(self.fresh_var_id); + let entry = self.fields_assignment.get_mut(&id).unwrap(); + for (attr, t, _) in entry.iter() { + if *attr == identifier { + return Ok(get_var(*t)); + } + } + entry.push((identifier, var_id, location)); + self.local_variables.push(VarDef { + name: None, + bound: Vec::new(), + }); + self.fresh_var_id += 1; + Ok(get_var(var_id)) + } + TypeEnum::ClassType(id, params) => { + let type_def = self.global.get_type_def(*id); + let field = type_def + .base + .fields + .get(identifier) + .map_or_else(|| Err("no such field".to_owned()), |v| Ok(v))?; + // function and tuple can have 0 type variables but with type parameters + // we require other types have the same number of type variables and type + // parameters in order to build a mapping + assert!(type_def.params.len() == 0 || type_def.params.len() == params.len()); + let map = type_def + .params + .clone() + .into_iter() + .zip(params.clone().into_iter()) + .collect(); + Ok(field.subst(&map)) + } + } + } + + pub fn get_type_def(&self, id: TypeId) -> &TypeDef { + self.global.get_type_def(id) + } +} diff --git a/nac3core/src/typecheck/context/global_context.rs b/nac3core/src/typecheck/context/global_context.rs deleted file mode 100644 index b322d7ea..00000000 --- a/nac3core/src/typecheck/context/global_context.rs +++ /dev/null @@ -1,109 +0,0 @@ -use super::super::typedef::*; -use std::collections::HashMap; -use std::rc::Rc; - -/// Structure for storing top-level type definitions. -/// Used for collecting type signature from source code. -/// Can be converted to `InferenceContext` for type inference in functions. -pub struct GlobalContext<'a> { - /// List of primitive definitions. - pub(super) primitive_defs: Vec>, - /// List of class definitions. - pub(super) class_defs: Vec>, - /// List of parametric type definitions. - pub(super) parametric_defs: Vec>, - /// List of type variable definitions. - pub(super) var_defs: Vec>, - /// Function name to signature mapping. - pub(super) fn_table: HashMap<&'a str, FnDef>, - - primitives: Vec, - variables: Vec, -} - -impl<'a> GlobalContext<'a> { - pub fn new(primitive_defs: Vec>) -> GlobalContext { - let mut primitives = Vec::new(); - for (i, t) in primitive_defs.iter().enumerate() { - primitives.push(TypeEnum::PrimitiveType(PrimitiveId(i)).into()); - } - GlobalContext { - primitive_defs, - class_defs: Vec::new(), - parametric_defs: Vec::new(), - var_defs: Vec::new(), - fn_table: HashMap::new(), - primitives, - variables: Vec::new(), - } - } - - pub fn add_class(&mut self, def: ClassDef<'a>) -> ClassId { - self.class_defs.push(def); - ClassId(self.class_defs.len() - 1) - } - - pub fn add_parametric(&mut self, def: ParametricDef<'a>) -> ParamId { - self.parametric_defs.push(def); - ParamId(self.parametric_defs.len() - 1) - } - - pub fn add_variable(&mut self, def: VarDef<'a>) -> VariableId { - self.add_variable_private(def) - } - - pub fn add_variable_private(&mut self, def: VarDef<'a>) -> VariableId { - self.var_defs.push(def); - self.variables - .push(TypeEnum::TypeVariable(VariableId(self.var_defs.len() - 1)).into()); - VariableId(self.var_defs.len() - 1) - } - - pub fn add_fn(&mut self, name: &'a str, def: FnDef) { - self.fn_table.insert(name, def); - } - - pub fn get_fn_def(&self, name: &str) -> Option<&FnDef> { - self.fn_table.get(name) - } - - pub fn get_primitive_def_mut(&mut self, id: PrimitiveId) -> &mut TypeDef<'a> { - self.primitive_defs.get_mut(id.0).unwrap() - } - - pub fn get_primitive_def(&self, id: PrimitiveId) -> &TypeDef { - self.primitive_defs.get(id.0).unwrap() - } - - pub fn get_class_def_mut(&mut self, id: ClassId) -> &mut ClassDef<'a> { - self.class_defs.get_mut(id.0).unwrap() - } - - pub fn get_class_def(&self, id: ClassId) -> &ClassDef { - self.class_defs.get(id.0).unwrap() - } - - pub fn get_parametric_def_mut(&mut self, id: ParamId) -> &mut ParametricDef<'a> { - self.parametric_defs.get_mut(id.0).unwrap() - } - - pub fn get_parametric_def(&self, id: ParamId) -> &ParametricDef { - self.parametric_defs.get(id.0).unwrap() - } - - pub fn get_variable_def_mut(&mut self, id: VariableId) -> &mut VarDef<'a> { - self.var_defs.get_mut(id.0).unwrap() - } - - pub fn get_variable_def(&self, id: VariableId) -> &VarDef { - self.var_defs.get(id.0).unwrap() - } - - pub fn get_primitive(&self, id: PrimitiveId) -> Type { - self.primitives.get(id.0).unwrap().clone() - } - - pub fn get_variable(&self, id: VariableId) -> Type { - self.variables.get(id.0).unwrap().clone() - } -} diff --git a/nac3core/src/typecheck/context/inference_context.rs b/nac3core/src/typecheck/context/inference_context.rs deleted file mode 100644 index d1c76a36..00000000 --- a/nac3core/src/typecheck/context/inference_context.rs +++ /dev/null @@ -1,202 +0,0 @@ -use super::super::location::{FileID, Location}; -use super::super::symbol_resolver::*; -use super::super::typedef::*; -use super::GlobalContext; -use rustpython_parser::ast; -use std::boxed::Box; -use std::collections::HashMap; - -struct ContextStack<'a> { - /// stack level, starts from 0 - level: u32, - /// stack of symbol definitions containing (name, level) where `level` is the smallest level - /// where the name is assigned a value - sym_def: Vec<(&'a str, u32)>, -} - -pub struct InferenceContext<'a> { - /// global context - global: GlobalContext<'a>, - /// per source symbol resolver - resolver: Box, - /// File ID - file: FileID, - - /// identifier to (type, readable) mapping. - /// an identifier might be defined earlier but has no value (for some code path), thus not - /// readable. - sym_table: HashMap<&'a str, (Type, bool, Location)>, - /// stack - stack: ContextStack<'a>, -} - -// non-trivial implementations here -impl<'a> InferenceContext<'a> { - pub fn new( - global: GlobalContext, - resolver: Box, - file: FileID, - ) -> InferenceContext { - InferenceContext { - global, - resolver, - file, - sym_table: HashMap::new(), - stack: ContextStack { - level: 0, - sym_def: Vec::new(), - }, - } - } - - /// execute the function with new scope. - /// variable assignment would be limited within the scope (not readable outside), and type - /// returns the list of variables assigned within the scope, and the result of the function - pub fn with_scope(&mut self, f: F) -> (Vec<(&'a str, Type, Location)>, R) - where - F: FnOnce(&mut Self) -> R, - { - self.stack.level += 1; - let result = f(self); - self.stack.level -= 1; - let mut poped_names = Vec::new(); - while !self.stack.sym_def.is_empty() { - let (_, level) = self.stack.sym_def.last().unwrap(); - if *level > self.stack.level { - let (name, _) = self.stack.sym_def.pop().unwrap(); - let (t, b, l) = self.sym_table.get_mut(name).unwrap(); - // set it to be unreadable - *b = false; - poped_names.push((name, t.clone(), *l)); - } else { - break; - } - } - (poped_names, result) - } - - /// assign a type to an identifier. - /// may return error if the identifier was defined but with different type - pub fn assign(&mut self, name: &'a str, ty: Type, loc: ast::Location) -> Result { - if let Some((t, x, _)) = self.sym_table.get_mut(name) { - if t == &ty { - if !*x { - self.stack.sym_def.push((name, self.stack.level)); - } - *x = true; - Ok(ty) - } else { - Err("different types".into()) - } - } else { - self.stack.sym_def.push((name, self.stack.level)); - self.sym_table.insert( - name, - (ty.clone(), true, Location::CodeRange(self.file, loc)), - ); - Ok(ty) - } - } - - /// get the type of an identifier - /// may return error if the identifier is not defined, and cannot be resolved with the - /// resolution function. - pub fn resolve(&self, name: &str) -> Result { - if let Some((t, x, _)) = self.sym_table.get(name) { - if *x { - Ok(t.clone()) - } else { - Err("may not be defined".into()) - } - } else { - match self.resolver.get_symbol_type(name) { - Some(SymbolType::Identifier(t)) => Ok(t), - Some(SymbolType::TypeName(_)) => Err("is not a value".into()), - _ => Err("unbounded identifier".into()), - } - } - } - - pub fn get_location(&self, name: &str) -> Option { - if let Some((_, _, l)) = self.sym_table.get(name) { - Some(*l) - } else { - self.resolver.get_symbol_location(name) - } - } -} - -// trivial getters: -impl<'a> InferenceContext<'a> { - pub fn get_primitive(&self, id: PrimitiveId) -> Type { - TypeEnum::PrimitiveType(id).into() - } - - pub fn get_variable(&self, id: VariableId) -> Type { - TypeEnum::TypeVariable(id).into() - } - - pub fn get_fn_def(&self, name: &str) -> Option<&FnDef> { - self.global.fn_table.get(name) - } - pub fn get_primitive_def(&self, id: PrimitiveId) -> &TypeDef { - self.global.primitive_defs.get(id.0).unwrap() - } - pub fn get_class_def(&self, id: ClassId) -> &ClassDef { - self.global.class_defs.get(id.0).unwrap() - } - pub fn get_parametric_def(&self, id: ParamId) -> &ParametricDef { - self.global.parametric_defs.get(id.0).unwrap() - } - pub fn get_variable_def(&self, id: VariableId) -> &VarDef { - self.global.var_defs.get(id.0).unwrap() - } - pub fn get_type(&self, name: &str) -> Result { - match self.resolver.get_symbol_type(name) { - Some(SymbolType::TypeName(t)) => Ok(t), - Some(SymbolType::Identifier(_)) => Err("not a type".into()), - _ => Err("unbounded identifier".into()), - } - } -} - -impl TypeEnum { - pub fn subst(&self, map: &HashMap) -> TypeEnum { - match self { - TypeEnum::TypeVariable(id) => map.get(id).map(|v| v.as_ref()).unwrap_or(self).clone(), - TypeEnum::ParametricType(id, params) => TypeEnum::ParametricType( - *id, - params - .iter() - .map(|v| v.as_ref().subst(map).into()) - .collect(), - ), - _ => self.clone(), - } - } - - pub fn get_subst(&self, ctx: &InferenceContext) -> HashMap { - match self { - TypeEnum::ParametricType(id, params) => { - let vars = &ctx.get_parametric_def(*id).params; - vars.iter() - .zip(params) - .map(|(v, p)| (*v, p.as_ref().clone().into())) - .collect() - } - // if this proves to be slow, we can use option type - _ => HashMap::new(), - } - } - - pub fn get_base<'b: 'a, 'a>(&'a self, ctx: &'b InferenceContext) -> Option<&'b TypeDef> { - match self { - TypeEnum::PrimitiveType(id) => Some(ctx.get_primitive_def(*id)), - TypeEnum::ClassType(id) | TypeEnum::VirtualClassType(id) => { - Some(&ctx.get_class_def(*id).base) - } - TypeEnum::ParametricType(id, _) => Some(&ctx.get_parametric_def(*id).base), - _ => None, - } - } -} diff --git a/nac3core/src/typecheck/context/mod.rs b/nac3core/src/typecheck/context/mod.rs deleted file mode 100644 index 3a5d8d11..00000000 --- a/nac3core/src/typecheck/context/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -mod inference_context; -mod global_context; -pub use inference_context::InferenceContext; -pub use global_context::GlobalContext; diff --git a/nac3core/src/typecheck/inference_core.rs b/nac3core/src/typecheck/inference_core.rs deleted file mode 100644 index 679c04c8..00000000 --- a/nac3core/src/typecheck/inference_core.rs +++ /dev/null @@ -1,525 +0,0 @@ -use super::context::InferenceContext; -use super::typedef::{TypeEnum::*, *}; -use std::collections::HashMap; - -fn find_subst( - ctx: &InferenceContext, - valuation: &Option<(VariableId, Type)>, - sub: &mut HashMap, - mut a: Type, - mut b: Type, -) -> Result<(), String> { - // TODO: fix error messages later - if let TypeVariable(id) = a.as_ref() { - if let Some((assumption_id, t)) = valuation { - if assumption_id == id { - a = t.clone(); - } - } - } - - let mut substituted = false; - if let TypeVariable(id) = b.as_ref() { - if let Some(c) = sub.get(&id) { - b = c.clone(); - substituted = true; - } - } - - match (a.as_ref(), b.as_ref()) { - (BotType, _) => Ok(()), - (TypeVariable(id_a), TypeVariable(id_b)) => { - if substituted { - return if id_a == id_b { - Ok(()) - } else { - Err("different variables".to_string()) - }; - } - let v_a = ctx.get_variable_def(*id_a); - let v_b = ctx.get_variable_def(*id_b); - if !v_b.bound.is_empty() { - if v_a.bound.is_empty() { - return Err("unbounded a".to_string()); - } else { - let diff: Vec<_> = v_a - .bound - .iter() - .filter(|x| !v_b.bound.contains(x)) - .collect(); - if !diff.is_empty() { - return Err("different domain".to_string()); - } - } - } - sub.insert(*id_b, a.clone()); - Ok(()) - } - (TypeVariable(id_a), _) => { - let v_a = ctx.get_variable_def(*id_a); - if v_a.bound.len() == 1 && v_a.bound[0].as_ref() == b.as_ref() { - Ok(()) - } else { - Err("different domain".to_string()) - } - } - (_, TypeVariable(id_b)) => { - let v_b = ctx.get_variable_def(*id_b); - if v_b.bound.is_empty() || v_b.bound.contains(&a) { - sub.insert(*id_b, a.clone()); - Ok(()) - } else { - Err("different domain".to_string()) - } - } - (_, VirtualClassType(id_b)) => { - let mut parents; - match a.as_ref() { - ClassType(id_a) => { - parents = [*id_a].to_vec(); - } - VirtualClassType(id_a) => { - parents = [*id_a].to_vec(); - } - _ => { - return Err("cannot substitute non-class type into virtual class".to_string()); - } - }; - while !parents.is_empty() { - if *id_b == parents[0] { - return Ok(()); - } - let c = ctx.get_class_def(parents.remove(0)); - parents.extend_from_slice(&c.parents); - } - Err("not subtype".to_string()) - } - (ParametricType(id_a, param_a), ParametricType(id_b, param_b)) => { - if id_a != id_b || param_a.len() != param_b.len() { - Err("different parametric types".to_string()) - } else { - for (x, y) in param_a.iter().zip(param_b.iter()) { - find_subst(ctx, valuation, sub, x.clone(), y.clone())?; - } - Ok(()) - } - } - (_, _) => { - if a == b { - Ok(()) - } else { - Err("not equal".to_string()) - } - } - } -} - -fn resolve_call_rec( - ctx: &InferenceContext, - valuation: &Option<(VariableId, Type)>, - obj: Option, - func: &str, - args: &[Type], -) -> Result, String> { - let mut subst = obj - .as_ref() - .map(|v| v.get_subst(ctx)) - .unwrap_or_else(HashMap::new); - - let fun = match &obj { - Some(obj) => { - let base = match obj.as_ref() { - PrimitiveType(id) => &ctx.get_primitive_def(*id), - ClassType(id) | VirtualClassType(id) => &ctx.get_class_def(*id).base, - ParametricType(id, _) => &ctx.get_parametric_def(*id).base, - _ => return Err("not supported".to_string()), - }; - base.methods.get(func) - } - None => ctx.get_fn_def(func), - } - .ok_or_else(|| "no such function".to_string())?; - - if args.len() != fun.args.len() { - return Err("incorrect parameter number".to_string()); - } - for (a, b) in args.iter().zip(fun.args.iter()) { - find_subst(ctx, valuation, &mut subst, a.clone(), b.clone())?; - } - let result = fun.result.as_ref().map(|v| v.subst(&subst)); - Ok(result.map(|result| { - if let SelfType = result { - obj.unwrap() - } else { - result.into() - } - })) -} - -pub fn resolve_call( - ctx: &InferenceContext, - obj: Option, - func: &str, - args: &[Type], -) -> Result, String> { - resolve_call_rec(ctx, &None, obj, func, args) -} - -#[cfg(test)] -mod tests { - use super::*; - use super::super::context::GlobalContext; - use super::super::primitives::*; - use std::rc::Rc; - - fn get_inference_context(ctx: GlobalContext) -> InferenceContext { - InferenceContext::new(ctx, Box::new(|_| Err("unbounded identifier".into()))) - } - - #[test] - fn test_simple_generic() { - let mut ctx = basic_ctx(); - let v1 = ctx.add_variable(VarDef { - name: "V1", - bound: vec![ctx.get_primitive(INT32_TYPE), ctx.get_primitive(FLOAT_TYPE)], - }); - let v1 = ctx.get_variable(v1); - let v2 = ctx.add_variable(VarDef { - name: "V2", - bound: vec![ - ctx.get_primitive(BOOL_TYPE), - ctx.get_primitive(INT32_TYPE), - ctx.get_primitive(FLOAT_TYPE), - ], - }); - let v2 = ctx.get_variable(v2); - let ctx = get_inference_context(ctx); - - assert_eq!( - resolve_call(&ctx, None, "int32", &[ctx.get_primitive(FLOAT_TYPE)]), - Ok(Some(ctx.get_primitive(INT32_TYPE))) - ); - - assert_eq!( - resolve_call(&ctx, None, "int32", &[ctx.get_primitive(INT32_TYPE)],), - Ok(Some(ctx.get_primitive(INT32_TYPE))) - ); - - assert_eq!( - resolve_call(&ctx, None, "float", &[ctx.get_primitive(INT32_TYPE)]), - Ok(Some(ctx.get_primitive(FLOAT_TYPE))) - ); - - assert_eq!( - resolve_call(&ctx, None, "float", &[ctx.get_primitive(BOOL_TYPE)]), - Err("different domain".to_string()) - ); - - assert_eq!( - resolve_call(&ctx, None, "float", &[]), - Err("incorrect parameter number".to_string()) - ); - - assert_eq!( - resolve_call(&ctx, None, "float", &[v1]), - Ok(Some(ctx.get_primitive(FLOAT_TYPE))) - ); - - assert_eq!( - resolve_call(&ctx, None, "float", &[v2]), - Err("different domain".to_string()) - ); - } - - #[test] - fn test_methods() { - let mut ctx = basic_ctx(); - - let v0 = ctx.add_variable(VarDef { - name: "V0", - bound: vec![], - }); - let v0 = ctx.get_variable(v0); - - let int32 = ctx.get_primitive(INT32_TYPE); - let int64 = ctx.get_primitive(INT64_TYPE); - let ctx = get_inference_context(ctx); - - // simple cases - assert_eq!( - resolve_call(&ctx, Some(int32.clone()), "__add__", &[int32.clone()]), - Ok(Some(int32.clone())) - ); - - assert_ne!( - resolve_call(&ctx, Some(int32.clone()), "__add__", &[int32.clone()]), - Ok(Some(int64.clone())) - ); - - assert_eq!( - resolve_call(&ctx, Some(int32), "__add__", &[int64]), - Err("not equal".to_string()) - ); - - // with type variables - assert_eq!( - resolve_call(&ctx, Some(v0.clone()), "__add__", &[v0.clone()]), - Err("not supported".into()) - ); - } - - #[test] - fn test_multi_generic() { - let mut ctx = basic_ctx(); - let v0 = ctx.add_variable(VarDef { - name: "V0", - bound: vec![], - }); - let v0 = ctx.get_variable(v0); - let v1 = ctx.add_variable(VarDef { - name: "V1", - bound: vec![], - }); - let v1 = ctx.get_variable(v1); - let v2 = ctx.add_variable(VarDef { - name: "V2", - bound: vec![], - }); - let v2 = ctx.get_variable(v2); - let v3 = ctx.add_variable(VarDef { - name: "V3", - bound: vec![], - }); - let v3 = ctx.get_variable(v3); - - ctx.add_fn( - "foo", - FnDef { - args: vec![v0.clone(), v0.clone(), v1.clone()], - result: Some(v0.clone()), - }, - ); - - ctx.add_fn( - "foo1", - FnDef { - args: vec![ParametricType(TUPLE_TYPE, vec![v0.clone(), v0.clone(), v1]).into()], - result: Some(v0), - }, - ); - let ctx = get_inference_context(ctx); - - assert_eq!( - resolve_call(&ctx, None, "foo", &[v2.clone(), v2.clone(), v2.clone()]), - Ok(Some(v2.clone())) - ); - assert_eq!( - resolve_call(&ctx, None, "foo", &[v2.clone(), v2.clone(), v3.clone()]), - Ok(Some(v2.clone())) - ); - assert_eq!( - resolve_call(&ctx, None, "foo", &[v2.clone(), v3.clone(), v3.clone()]), - Err("different variables".to_string()) - ); - - assert_eq!( - resolve_call( - &ctx, - None, - "foo1", - &[ParametricType(TUPLE_TYPE, vec![v2.clone(), v2.clone(), v2.clone()]).into()] - ), - Ok(Some(v2.clone())) - ); - assert_eq!( - resolve_call( - &ctx, - None, - "foo1", - &[ParametricType(TUPLE_TYPE, vec![v2.clone(), v2.clone(), v3.clone()]).into()] - ), - Ok(Some(v2.clone())) - ); - assert_eq!( - resolve_call( - &ctx, - None, - "foo1", - &[ParametricType(TUPLE_TYPE, vec![v2, v3.clone(), v3]).into()] - ), - Err("different variables".to_string()) - ); - } - - #[test] - fn test_class_generics() { - let mut ctx = basic_ctx(); - - let list = ctx.get_parametric_def_mut(LIST_TYPE); - let t = Rc::new(TypeVariable(list.params[0])); - list.base.methods.insert( - "head", - FnDef { - args: vec![], - result: Some(t.clone()), - }, - ); - list.base.methods.insert( - "append", - FnDef { - args: vec![t], - result: None, - }, - ); - - let v0 = ctx.add_variable(VarDef { - name: "V0", - bound: vec![], - }); - let v0 = ctx.get_variable(v0); - let v1 = ctx.add_variable(VarDef { - name: "V1", - bound: vec![], - }); - let v1 = ctx.get_variable(v1); - let ctx = get_inference_context(ctx); - - assert_eq!( - resolve_call( - &ctx, - Some(ParametricType(LIST_TYPE, vec![v0.clone()]).into()), - "head", - &[] - ), - Ok(Some(v0.clone())) - ); - assert_eq!( - resolve_call( - &ctx, - Some(ParametricType(LIST_TYPE, vec![v0.clone()]).into()), - "append", - &[v0.clone()] - ), - Ok(None) - ); - assert_eq!( - resolve_call( - &ctx, - Some(ParametricType(LIST_TYPE, vec![v0]).into()), - "append", - &[v1] - ), - Err("different variables".to_string()) - ); - } - - #[test] - fn test_virtual_class() { - let mut ctx = basic_ctx(); - - let foo = ctx.add_class(ClassDef { - base: TypeDef { - name: "Foo", - methods: HashMap::new(), - fields: HashMap::new(), - }, - parents: vec![], - }); - - let foo1 = ctx.add_class(ClassDef { - base: TypeDef { - name: "Foo1", - methods: HashMap::new(), - fields: HashMap::new(), - }, - parents: vec![foo], - }); - - let foo2 = ctx.add_class(ClassDef { - base: TypeDef { - name: "Foo2", - methods: HashMap::new(), - fields: HashMap::new(), - }, - parents: vec![foo1], - }); - - let bar = ctx.add_class(ClassDef { - base: TypeDef { - name: "bar", - methods: HashMap::new(), - fields: HashMap::new(), - }, - parents: vec![], - }); - - ctx.add_fn( - "foo", - FnDef { - args: vec![VirtualClassType(foo).into()], - result: None, - }, - ); - ctx.add_fn( - "foo1", - FnDef { - args: vec![VirtualClassType(foo1).into()], - result: None, - }, - ); - let ctx = get_inference_context(ctx); - - assert_eq!( - resolve_call(&ctx, None, "foo", &[ClassType(foo).into()]), - Ok(None) - ); - - assert_eq!( - resolve_call(&ctx, None, "foo", &[ClassType(foo1).into()]), - Ok(None) - ); - - assert_eq!( - resolve_call(&ctx, None, "foo", &[ClassType(foo2).into()]), - Ok(None) - ); - - assert_eq!( - resolve_call(&ctx, None, "foo", &[ClassType(bar).into()]), - Err("not subtype".to_string()) - ); - - assert_eq!( - resolve_call(&ctx, None, "foo1", &[ClassType(foo1).into()]), - Ok(None) - ); - - assert_eq!( - resolve_call(&ctx, None, "foo1", &[ClassType(foo2).into()]), - Ok(None) - ); - - assert_eq!( - resolve_call(&ctx, None, "foo1", &[ClassType(foo).into()]), - Err("not subtype".to_string()) - ); - - // virtual class substitution - assert_eq!( - resolve_call(&ctx, None, "foo", &[VirtualClassType(foo).into()]), - Ok(None) - ); - assert_eq!( - resolve_call(&ctx, None, "foo", &[VirtualClassType(foo1).into()]), - Ok(None) - ); - assert_eq!( - resolve_call(&ctx, None, "foo", &[VirtualClassType(foo2).into()]), - Ok(None) - ); - assert_eq!( - resolve_call(&ctx, None, "foo", &[VirtualClassType(bar).into()]), - Err("not subtype".to_string()) - ); - } -} diff --git a/nac3core/src/typecheck/mod.rs b/nac3core/src/typecheck/mod.rs index a3be5925..7ab82585 100644 --- a/nac3core/src/typecheck/mod.rs +++ b/nac3core/src/typecheck/mod.rs @@ -1,7 +1,6 @@ -pub mod context; -pub mod inference_core; +mod context; pub mod location; -pub mod magic_methods; -pub mod primitives; +mod magic_methods; +mod primitives; pub mod symbol_resolver; pub mod typedef; diff --git a/nac3core/src/typecheck/primitives.rs b/nac3core/src/typecheck/primitives.rs index 94e76ee7..c383e955 100644 --- a/nac3core/src/typecheck/primitives.rs +++ b/nac3core/src/typecheck/primitives.rs @@ -1,184 +1,168 @@ -use super::typedef::{TypeEnum::*, *}; use super::context::*; +use super::typedef::{TypeEnum::*, *}; use std::collections::HashMap; +use std::rc::Rc; -pub const TUPLE_TYPE: ParamId = ParamId(0); -pub const LIST_TYPE: ParamId = ParamId(1); +pub const FUNC_TYPE: TypeId = TypeId(0); +pub const TUPLE_TYPE: TypeId = TypeId(1); +pub const LIST_TYPE: TypeId = TypeId(2); +pub const VIRTUAL_TYPE: TypeId = TypeId(3); +pub const NONE_TYPE: TypeId = TypeId(4); -pub const BOOL_TYPE: PrimitiveId = PrimitiveId(0); -pub const INT32_TYPE: PrimitiveId = PrimitiveId(1); -pub const INT64_TYPE: PrimitiveId = PrimitiveId(2); -pub const FLOAT_TYPE: PrimitiveId = PrimitiveId(3); +pub const BOOL_TYPE: TypeId = TypeId(5); +pub const INT32_TYPE: TypeId = TypeId(6); +pub const INT64_TYPE: TypeId = TypeId(7); +pub const FLOAT_TYPE: TypeId = TypeId(8); -fn impl_math(def: &mut TypeDef, ty: &Type) { - let result = Some(ty.clone()); - let fun = FnDef { - args: vec![ty.clone()], - result: result.clone(), - }; - def.methods.insert("__add__", fun.clone()); - def.methods.insert("__sub__", fun.clone()); - def.methods.insert("__mul__", fun.clone()); - def.methods.insert( - "__neg__", - FnDef { - args: vec![], - result, - }, - ); - def.methods.insert( - "__truediv__", - FnDef { - args: vec![ty.clone()], - result: Some(PrimitiveType(FLOAT_TYPE).into()), - }, - ); - def.methods.insert("__floordiv__", fun.clone()); - def.methods.insert("__mod__", fun.clone()); - def.methods.insert("__pow__", fun); +fn primitive(base: BaseDef) -> TypeDef { + TypeDef { + base, + parents: vec![], + params: vec![], + } } -fn impl_bits(def: &mut TypeDef, ty: &Type) { - let result = Some(ty.clone()); - let fun = FnDef { - args: vec![PrimitiveType(INT32_TYPE).into()], - result, - }; - - def.methods.insert("__lshift__", fun.clone()); - def.methods.insert("__rshift__", fun); - def.methods.insert( - "__xor__", - FnDef { - args: vec![ty.clone()], - result: Some(ty.clone()), - }, - ); +pub fn get_fn(from: Type, to: Type) -> Type { + Rc::new(ClassType(FUNC_TYPE, vec![from, to])) } -fn impl_eq(def: &mut TypeDef, ty: &Type) { - let fun = FnDef { - args: vec![ty.clone()], - result: Some(PrimitiveType(BOOL_TYPE).into()), - }; - - def.methods.insert("__eq__", fun.clone()); - def.methods.insert("__ne__", fun); +pub fn get_tuple(types: &[Type]) -> Type { + Rc::new(ClassType(TUPLE_TYPE, types.to_vec())) } -fn impl_order(def: &mut TypeDef, ty: &Type) { - let fun = FnDef { - args: vec![ty.clone()], - result: Some(PrimitiveType(BOOL_TYPE).into()), - }; +pub fn get_list(t: Type) -> Type { + Rc::new(ClassType(LIST_TYPE, vec![t])) +} - def.methods.insert("__lt__", fun.clone()); - def.methods.insert("__gt__", fun.clone()); - def.methods.insert("__le__", fun.clone()); - def.methods.insert("__ge__", fun); +pub fn get_virtual(t: Type) -> Type { + Rc::new(ClassType(VIRTUAL_TYPE, vec![t])) +} + +pub fn get_none() -> Type { + Rc::new(ClassType(NONE_TYPE, Vec::new())) +} + +pub fn get_bool() -> Type { + Rc::new(ClassType(BOOL_TYPE, Vec::new())) +} +pub fn get_int32() -> Type { + Rc::new(ClassType(INT32_TYPE, Vec::new())) +} + +pub fn get_int64() -> Type { + Rc::new(ClassType(INT64_TYPE, Vec::new())) +} + +pub fn get_float() -> Type { + Rc::new(ClassType(FLOAT_TYPE, Vec::new())) +} + +pub fn get_var(id: VariableId) -> Type { + Rc::new(TypeVariable(id)) +} + +fn impl_math(def: &mut BaseDef, ty: &Type) { + let fun = get_fn(ty.clone(), ty.clone()); + def.fields.insert("__add__", fun.clone()); + def.fields.insert("__sub__", fun.clone()); + def.fields.insert("__mul__", fun.clone()); + def.fields.insert("__neg__", get_fn(get_none(), ty.clone())); + def.fields + .insert("__truediv__", get_fn(ty.clone(), get_float())); + def.fields.insert("__floordiv__", fun.clone()); + def.fields.insert("__mod__", fun.clone()); + def.fields.insert("__pow__", fun); +} + +fn impl_bits(def: &mut BaseDef, ty: &Type) { + let fun = get_fn(get_int32(), ty.clone()); + + def.fields.insert("__lshift__", fun.clone()); + def.fields.insert("__rshift__", fun); + def.fields.insert("__xor__", get_fn(ty.clone(), ty.clone())); +} + +fn impl_eq(def: &mut BaseDef, ty: &Type) { + let fun = get_fn(ty.clone(), get_bool()); + + def.fields.insert("__eq__", fun.clone()); + def.fields.insert("__ne__", fun); +} + +fn impl_order(def: &mut BaseDef, ty: &Type) { + let fun = get_fn(ty.clone(), get_bool()); + + def.fields.insert("__lt__", fun.clone()); + def.fields.insert("__gt__", fun.clone()); + def.fields.insert("__le__", fun.clone()); + def.fields.insert("__ge__", fun); } pub fn basic_ctx() -> GlobalContext<'static> { - let primitives = [ - TypeDef { + let mut ctx = GlobalContext::new(vec![ + primitive(BaseDef { + name: "function", + fields: HashMap::new(), + }), + primitive(BaseDef { + name: "tuple", + fields: HashMap::new(), + }), + primitive(BaseDef { + name: "list", + fields: HashMap::new(), + }), + primitive(BaseDef { + name: "virtual", + fields: HashMap::new(), + }), + primitive(BaseDef { + name: "None", + fields: HashMap::new(), + }), + primitive(BaseDef { name: "bool", fields: HashMap::new(), - methods: HashMap::new(), - }, - TypeDef { + }), + primitive(BaseDef { name: "int32", fields: HashMap::new(), - methods: HashMap::new(), - }, - TypeDef { + }), + primitive(BaseDef { name: "int64", fields: HashMap::new(), - methods: HashMap::new(), - }, - TypeDef { + }), + primitive(BaseDef { name: "float", fields: HashMap::new(), - methods: HashMap::new(), - }, - ] - .to_vec(); - let mut ctx = GlobalContext::new(primitives); + }), + ]); - let b = ctx.get_primitive(BOOL_TYPE); - let b_def = ctx.get_primitive_def_mut(BOOL_TYPE); - impl_eq(b_def, &b); - let int32 = ctx.get_primitive(INT32_TYPE); - let int32_def = ctx.get_primitive_def_mut(INT32_TYPE); + let t = ctx.add_variable(VarDef { + name: Some("T"), + bound: vec![], + }); + ctx.get_type_def_mut(LIST_TYPE).params.push(t); + + let b_def = ctx.get_type_def_mut(BOOL_TYPE); + impl_eq(&mut b_def.base, &get_bool()); + let int32 = get_int32(); + let int32_def = &mut ctx.get_type_def_mut(INT32_TYPE).base; impl_math(int32_def, &int32); impl_bits(int32_def, &int32); impl_order(int32_def, &int32); impl_eq(int32_def, &int32); - let int64 = ctx.get_primitive(INT64_TYPE); - let int64_def = ctx.get_primitive_def_mut(INT64_TYPE); + let int64 = get_int64(); + let int64_def = &mut ctx.get_type_def_mut(INT64_TYPE).base; impl_math(int64_def, &int64); impl_bits(int64_def, &int64); impl_order(int64_def, &int64); impl_eq(int64_def, &int64); - let float = ctx.get_primitive(FLOAT_TYPE); - let float_def = ctx.get_primitive_def_mut(FLOAT_TYPE); + let float = get_float(); + let float_def = &mut ctx.get_type_def_mut(FLOAT_TYPE).base; impl_math(float_def, &float); impl_order(float_def, &float); impl_eq(float_def, &float); - let t = ctx.add_variable_private(VarDef { - name: "T", - bound: vec![], - }); - - ctx.add_parametric(ParametricDef { - base: TypeDef { - name: "tuple", - fields: HashMap::new(), - methods: HashMap::new(), - }, - // we have nothing for tuple, so no param def - params: vec![], - }); - - ctx.add_parametric(ParametricDef { - base: TypeDef { - name: "list", - fields: HashMap::new(), - methods: HashMap::new(), - }, - params: vec![t], - }); - - let i = ctx.add_variable_private(VarDef { - name: "I", - bound: vec![ - PrimitiveType(INT32_TYPE).into(), - PrimitiveType(INT64_TYPE).into(), - PrimitiveType(FLOAT_TYPE).into(), - ], - }); - let args = vec![TypeVariable(i).into()]; - ctx.add_fn( - "int32", - FnDef { - args: args.clone(), - result: Some(PrimitiveType(INT32_TYPE).into()), - }, - ); - ctx.add_fn( - "int64", - FnDef { - args: args.clone(), - result: Some(PrimitiveType(INT64_TYPE).into()), - }, - ); - ctx.add_fn( - "float", - FnDef { - args, - result: Some(PrimitiveType(FLOAT_TYPE).into()), - }, - ); - ctx } diff --git a/nac3core/src/typecheck/typedef.rs b/nac3core/src/typecheck/typedef.rs index bec61fd1..7c447b61 100644 --- a/nac3core/src/typecheck/typedef.rs +++ b/nac3core/src/typecheck/typedef.rs @@ -1,60 +1,64 @@ use std::collections::HashMap; +use std::collections::HashSet; use std::rc::Rc; -#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)] -pub struct PrimitiveId(pub(crate) usize); - -#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)] -pub struct ClassId(pub(crate) usize); - -#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)] -pub struct ParamId(pub(crate) usize); - #[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)] pub struct VariableId(pub(crate) usize); +#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)] +pub struct TypeId(pub(crate) usize); + #[derive(PartialEq, Eq, Clone, Hash, Debug)] pub enum TypeEnum { - BotType, - SelfType, - PrimitiveType(PrimitiveId), - ClassType(ClassId), - VirtualClassType(ClassId), - ParametricType(ParamId, Vec>), + ClassType(TypeId, Vec>), TypeVariable(VariableId), } pub type Type = Rc; #[derive(Clone)] -pub struct FnDef { - // we assume methods first argument to be SelfType, - // so the first argument is not contained here - pub args: Vec, - pub result: Option, +pub struct BaseDef<'a> { + pub name: &'a str, + pub fields: HashMap<&'a str, Type>, } #[derive(Clone)] pub struct TypeDef<'a> { - pub name: &'a str, - pub fields: HashMap<&'a str, Type>, - pub methods: HashMap<&'a str, FnDef>, -} - -#[derive(Clone)] -pub struct ClassDef<'a> { - pub base: TypeDef<'a>, - pub parents: Vec, -} - -#[derive(Clone)] -pub struct ParametricDef<'a> { - pub base: TypeDef<'a>, + pub base: BaseDef<'a>, + pub parents: Vec, pub params: Vec, } #[derive(Clone)] pub struct VarDef<'a> { - pub name: &'a str, + pub name: Option<&'a str>, pub bound: Vec, } + +impl TypeEnum { + pub fn get_vars(&self, vars: &mut HashSet) { + match self { + TypeEnum::TypeVariable(id) => { + vars.insert(*id); + } + TypeEnum::ClassType(_, params) => { + for t in params.iter() { + t.get_vars(vars) + } + } + } + } + + pub fn subst(&self, map: &HashMap) -> Type { + match self { + TypeEnum::TypeVariable(id) => map + .get(id) + .cloned() + .unwrap_or_else(|| Rc::new(self.clone())), + TypeEnum::ClassType(id, params) => Rc::new(TypeEnum::ClassType( + *id, + params.iter().map(|t| t.subst(map)).collect(), + )), + } + } +}