diff --git a/nac3core/src/context/inference_context.rs b/nac3core/src/context/inference_context.rs new file mode 100644 index 00000000..f321f2e1 --- /dev/null +++ b/nac3core/src/context/inference_context.rs @@ -0,0 +1,226 @@ +use super::TopLevelContext; +use crate::typedef::*; +use std::boxed::Box; +use std::collections::HashMap; + +struct ContextStack<'a> { + /// stack level, starts from 0 + level: u32, + /// stack of variable definitions containing (id, def, level) where `def` is the original + /// definition in `level-1`. + var_defs: Vec<(usize, VarDef<'a>, u32)>, + /// stack of symbol definitions containing (name, level) where `level` is the smallest level + /// where the name is assigned a value + sym_def: Vec<(&'a str, u32)>, +} + +pub struct InferenceContext<'a> { + /// top level context + top_level: TopLevelContext<'a>, + + /// list of primitive instances + primitives: Vec, + /// list of variable instances + variables: Vec, + /// identifier to (type, readable) mapping. + /// an identifier might be defined earlier but has no value (for some code path), thus not + /// readable. + sym_table: HashMap<&'a str, (Type, bool)>, + /// resolution function reference, that may resolve unbounded identifiers to some type + resolution_fn: Box Result>, + /// stack + stack: ContextStack<'a>, +} + +// non-trivial implementations here +impl<'a> InferenceContext<'a> { + /// return a new `InferenceContext` from `TopLevelContext` and resolution function. + pub fn new( + top_level: TopLevelContext, + resolution_fn: Box Result>, + ) -> InferenceContext { + let primitives = (0..top_level.primitive_defs.len()) + .map(|v| TypeEnum::PrimitiveType(PrimitiveId(v)).into()) + .collect(); + let variables = (0..top_level.var_defs.len()) + .map(|v| TypeEnum::TypeVariable(VariableId(v)).into()) + .collect(); + InferenceContext { + top_level, + primitives, + variables, + sym_table: HashMap::new(), + resolution_fn, + stack: ContextStack { + level: 0, + var_defs: Vec::new(), + sym_def: Vec::new(), + }, + } + } + + /// execute the function with new scope. + /// variable assignment would be limited within the scope (not readable outside), and type + /// variable type guard would be limited within the scope. + /// returns the list of variables assigned within the scope, and the result of the function + pub fn with_scope(&mut self, f: F) -> (Vec<&'a str>, R) + where + F: FnOnce(&mut Self) -> R, + { + self.stack.level += 1; + let result = f(self); + self.stack.level -= 1; + while self.stack.var_defs.len() > 0 { + let (_, _, level) = self.stack.var_defs.last().unwrap(); + if *level > self.stack.level { + let (id, def, _) = self.stack.var_defs.pop().unwrap(); + self.top_level.var_defs[id] = def; + } else { + break; + } + } + let mut poped_names = Vec::new(); + while self.stack.sym_def.len() > 0 { + let (_, level) = self.stack.sym_def.last().unwrap(); + if *level > self.stack.level { + let (name, _) = self.stack.sym_def.pop().unwrap(); + self.sym_table.get_mut(name).unwrap().1 = false; + poped_names.push(name); + } else { + break; + } + } + (poped_names, result) + } + + /// assign a type to an identifier. + /// may return error if the identifier was defined but with different type + pub fn assign(&mut self, name: &'a str, ty: Type) -> Result { + if let Some((t, x)) = self.sym_table.get_mut(name) { + if t == &ty { + if !*x { + self.stack.sym_def.push((name, self.stack.level)); + } + *x = true; + Ok(ty) + } else { + Err("different types".into()) + } + } else { + self.stack.sym_def.push((name, self.stack.level)); + self.sym_table.insert(name, (ty.clone(), true)); + Ok(ty) + } + } + + /// get the type of an identifier + /// may return error if the identifier is not defined, and cannot be resolved with the + /// resolution function. + pub fn resolve(&mut self, name: &'a str) -> Result { + if let Some((t, x)) = self.sym_table.get(name) { + if *x { + Ok(t.clone()) + } else { + Err("may not have value".into()) + } + } else { + self.resolution_fn.as_mut()(name) + } + } + + /// restrict the bound of a type variable by replacing its definition. + /// used for implementing type guard + pub fn restrict(&mut self, id: VariableId, mut def: VarDef<'a>) { + std::mem::swap(self.top_level.var_defs.get_mut(id.0).unwrap(), &mut def); + self.stack.var_defs.push((id.0, def, self.stack.level)); + } +} + +// trivial getters: +impl<'a> InferenceContext<'a> { + pub fn get_primitive(&self, id: PrimitiveId) -> Type { + self.primitives.get(id.0).unwrap().clone() + } + pub fn get_variable(&self, id: VariableId) -> Type { + self.variables.get(id.0).unwrap().clone() + } + + pub fn get_fn_def(&self, name: &str) -> Option<&FnDef> { + self.top_level.fn_table.get(name) + } + pub fn get_primitive_def(&self, id: PrimitiveId) -> &TypeDef { + self.top_level.primitive_defs.get(id.0).unwrap() + } + pub fn get_class_def(&self, id: ClassId) -> &ClassDef { + self.top_level.class_defs.get(id.0).unwrap() + } + pub fn get_parametric_def(&self, id: ParamId) -> &ParametricDef { + self.top_level.parametric_defs.get(id.0).unwrap() + } + pub fn get_variable_def(&self, id: VariableId) -> &VarDef { + self.top_level.var_defs.get(id.0).unwrap() + } + pub fn get_type(&self, name: &str) -> Option { + self.top_level.get_type(name) + } +} + +impl TypeEnum { + pub fn subst(&self, map: &HashMap) -> TypeEnum { + match self { + TypeEnum::TypeVariable(id) => map.get(id).map(|v| v.as_ref()).unwrap_or(self).clone(), + TypeEnum::ParametricType(id, params) => TypeEnum::ParametricType( + *id, + params + .iter() + .map(|v| v.as_ref().subst(map).into()) + .collect(), + ), + _ => self.clone(), + } + } + + pub fn inv_subst(&self, map: &[(Type, Type)]) -> Type { + for (from, to) in map.iter() { + if self == from.as_ref() { + return to.clone(); + } + } + match self { + TypeEnum::ParametricType(id, params) => TypeEnum::ParametricType( + *id, + params + .iter() + .map(|v| v.as_ref().inv_subst(map).into()) + .collect(), + ), + _ => self.clone(), + } + .into() + } + + pub fn get_subst(&self, ctx: &InferenceContext) -> HashMap { + match self { + TypeEnum::ParametricType(id, params) => { + let vars = &ctx.get_parametric_def(*id).params; + vars.iter() + .zip(params) + .map(|(v, p)| (*v, p.as_ref().clone().into())) + .collect() + } + // if this proves to be slow, we can use option type + _ => HashMap::new(), + } + } + + pub fn get_base<'b: 'a, 'a>(&'a self, ctx: &'b InferenceContext) -> Option<&'b TypeDef> { + match self { + TypeEnum::PrimitiveType(id) => Some(ctx.get_primitive_def(*id)), + TypeEnum::ClassType(id) | TypeEnum::VirtualClassType(id) => { + Some(&ctx.get_class_def(*id).base) + } + TypeEnum::ParametricType(id, _) => Some(&ctx.get_parametric_def(*id).base), + _ => None, + } + } +} diff --git a/nac3core/src/context/mod.rs b/nac3core/src/context/mod.rs new file mode 100644 index 00000000..88e2a43a --- /dev/null +++ b/nac3core/src/context/mod.rs @@ -0,0 +1,5 @@ +mod top_level_context; +mod inference_context; +pub use top_level_context::TopLevelContext; +pub use inference_context::InferenceContext; + diff --git a/nac3core/src/context/top_level_context.rs b/nac3core/src/context/top_level_context.rs new file mode 100644 index 00000000..b001418a --- /dev/null +++ b/nac3core/src/context/top_level_context.rs @@ -0,0 +1,119 @@ +use crate::typedef::*; +use std::collections::HashMap; +use std::rc::Rc; + +/// Structure for storing top-level type definitions. +/// Used for collecting type signature from source code. +/// Can be converted to `InferenceContext` for type inference in functions. +pub struct TopLevelContext<'a> { + /// List of primitive definitions. + pub(super) primitive_defs: Vec>, + /// List of class definitions. + pub(super) class_defs: Vec>, + /// List of parametric type definitions. + pub(super) parametric_defs: Vec>, + /// List of type variable definitions. + pub(super) var_defs: Vec>, + /// Function name to signature mapping. + pub(super) fn_table: HashMap<&'a str, FnDef>, + /// Type name to type mapping. + pub(super) sym_table: HashMap<&'a str, Type>, +} + +impl<'a> TopLevelContext<'a> { + pub fn new(primitives: Vec>) -> TopLevelContext { + let mut sym_table = HashMap::new(); + for (i, t) in primitives.iter().enumerate() { + sym_table.insert(t.name, TypeEnum::PrimitiveType(PrimitiveId(i)).into()); + } + return TopLevelContext { + primitive_defs: primitives, + class_defs: Vec::new(), + parametric_defs: Vec::new(), + var_defs: Vec::new(), + fn_table: HashMap::new(), + sym_table, + }; + } + + pub fn add_class(&mut self, def: ClassDef<'a>) -> ClassId { + self.sym_table.insert( + def.base.name, + TypeEnum::ClassType(ClassId(self.class_defs.len())).into(), + ); + self.class_defs.push(def); + ClassId(self.class_defs.len() - 1) + } + + pub fn add_parametric(&mut self, def: ParametricDef<'a>) -> ParamId { + let params = def + .params + .iter() + .map(|&v| Rc::new(TypeEnum::TypeVariable(v))) + .collect(); + self.sym_table.insert( + def.base.name, + TypeEnum::ParametricType(ParamId(self.parametric_defs.len()), params).into(), + ); + self.parametric_defs.push(def); + ParamId(self.parametric_defs.len() - 1) + } + + pub fn add_variable(&mut self, def: VarDef<'a>) -> VariableId { + self.sym_table.insert( + def.name, + TypeEnum::TypeVariable(VariableId(self.var_defs.len())).into(), + ); + self.add_variable_private(def) + } + + pub fn add_variable_private(&mut self, def: VarDef<'a>) -> VariableId { + self.var_defs.push(def); + VariableId(self.var_defs.len() - 1) + } + + pub fn add_fn(&mut self, name: &'a str, def: FnDef) { + self.fn_table.insert(name, def); + } + + pub fn get_fn(&self, name: &str) -> Option<&FnDef> { + self.fn_table.get(name) + } + + pub fn get_primitive_mut(&mut self, id: PrimitiveId) -> &mut TypeDef<'a> { + self.primitive_defs.get_mut(id.0).unwrap() + } + + pub fn get_primitive(&self, id: PrimitiveId) -> &TypeDef { + self.primitive_defs.get(id.0).unwrap() + } + + pub fn get_class_mut(&mut self, id: ClassId) -> &mut ClassDef<'a> { + self.class_defs.get_mut(id.0).unwrap() + } + + pub fn get_class(&self, id: ClassId) -> &ClassDef { + self.class_defs.get(id.0).unwrap() + } + + pub fn get_parametric_mut(&mut self, id: ParamId) -> &mut ParametricDef<'a> { + self.parametric_defs.get_mut(id.0).unwrap() + } + + pub fn get_parametric(&self, id: ParamId) -> &ParametricDef { + self.parametric_defs.get(id.0).unwrap() + } + + pub fn get_variable_mut(&mut self, id: VariableId) -> &mut VarDef<'a> { + self.var_defs.get_mut(id.0).unwrap() + } + + pub fn get_variable(&self, id: VariableId) -> &VarDef { + self.var_defs.get(id.0).unwrap() + } + + pub fn get_type(&self, name: &str) -> Option { + // TODO: handle parametric types + self.sym_table.get(name).map(|v| v.clone()) + } +} diff --git a/nac3core/src/lib.rs b/nac3core/src/lib.rs index ab521b4a..5a296655 100644 --- a/nac3core/src/lib.rs +++ b/nac3core/src/lib.rs @@ -2,11 +2,12 @@ extern crate num_bigint; extern crate inkwell; extern crate rustpython_parser; -pub mod expression; -pub mod inference; +// pub mod expression; +// pub mod inference; mod operators; -pub mod primitives; +// pub mod primitives; pub mod typedef; +pub mod context; use std::error::Error; use std::fmt; diff --git a/nac3core/src/typedef.rs b/nac3core/src/typedef.rs index 5fe42279..bec61fd1 100644 --- a/nac3core/src/typedef.rs +++ b/nac3core/src/typedef.rs @@ -14,28 +14,30 @@ pub struct ParamId(pub(crate) usize); pub struct VariableId(pub(crate) usize); #[derive(PartialEq, Eq, Clone, Hash, Debug)] -pub enum Type { +pub enum TypeEnum { BotType, SelfType, PrimitiveType(PrimitiveId), ClassType(ClassId), VirtualClassType(ClassId), - ParametricType(ParamId, Vec>), + ParametricType(ParamId, Vec>), TypeVariable(VariableId), } +pub type Type = Rc; + #[derive(Clone)] pub struct FnDef { // we assume methods first argument to be SelfType, // so the first argument is not contained here - pub args: Vec>, - pub result: Option>, + pub args: Vec, + pub result: Option, } #[derive(Clone)] pub struct TypeDef<'a> { pub name: &'a str, - pub fields: HashMap<&'a str, Rc>, + pub fields: HashMap<&'a str, Type>, pub methods: HashMap<&'a str, FnDef>, } @@ -54,170 +56,5 @@ pub struct ParametricDef<'a> { #[derive(Clone)] pub struct VarDef<'a> { pub name: &'a str, - pub bound: Vec>, -} - -pub struct GlobalContext<'a> { - primitive_defs: Vec>, - class_defs: Vec>, - parametric_defs: Vec>, - var_defs: Vec>, - sym_table: HashMap<&'a str, Type>, - fn_table: HashMap<&'a str, FnDef>, -} - -impl<'a> GlobalContext<'a> { - pub fn new(primitives: Vec>) -> GlobalContext { - let mut sym_table = HashMap::new(); - for (i, t) in primitives.iter().enumerate() { - sym_table.insert(t.name, Type::PrimitiveType(PrimitiveId(i))); - } - return GlobalContext { - primitive_defs: primitives, - class_defs: Vec::new(), - parametric_defs: Vec::new(), - var_defs: Vec::new(), - fn_table: HashMap::new(), - sym_table, - }; - } - - pub fn add_class(&mut self, def: ClassDef<'a>) -> ClassId { - self.sym_table.insert( - def.base.name, - Type::ClassType(ClassId(self.class_defs.len())), - ); - self.class_defs.push(def); - ClassId(self.class_defs.len() - 1) - } - - pub fn add_parametric(&mut self, def: ParametricDef<'a>) -> ParamId { - let params = def - .params - .iter() - .map(|&v| Rc::new(Type::TypeVariable(v))) - .collect(); - self.sym_table.insert( - def.base.name, - Type::ParametricType(ParamId(self.parametric_defs.len()), params), - ); - self.parametric_defs.push(def); - ParamId(self.parametric_defs.len() - 1) - } - - pub fn add_variable(&mut self, def: VarDef<'a>) -> VariableId { - self.sym_table.insert( - def.name, - Type::TypeVariable(VariableId(self.var_defs.len())), - ); - self.add_variable_private(def) - } - - pub fn add_variable_private(&mut self, def: VarDef<'a>) -> VariableId { - self.var_defs.push(def); - VariableId(self.var_defs.len() - 1) - } - - pub fn add_fn(&mut self, name: &'a str, def: FnDef) { - self.fn_table.insert(name, def); - } - - pub fn get_fn(&self, name: &str) -> Option<&FnDef> { - self.fn_table.get(name) - } - - pub fn get_primitive_mut(&mut self, id: PrimitiveId) -> &mut TypeDef<'a> { - self.primitive_defs.get_mut(id.0).unwrap() - } - - pub fn get_primitive(&self, id: PrimitiveId) -> &TypeDef { - self.primitive_defs.get(id.0).unwrap() - } - - pub fn get_class_mut(&mut self, id: ClassId) -> &mut ClassDef<'a> { - self.class_defs.get_mut(id.0).unwrap() - } - - pub fn get_class(&self, id: ClassId) -> &ClassDef { - self.class_defs.get(id.0).unwrap() - } - - pub fn get_parametric_mut(&mut self, id: ParamId) -> &mut ParametricDef<'a> { - self.parametric_defs.get_mut(id.0).unwrap() - } - - pub fn get_parametric(&self, id: ParamId) -> &ParametricDef { - self.parametric_defs.get(id.0).unwrap() - } - - pub fn get_variable_mut(&mut self, id: VariableId) -> &mut VarDef<'a> { - self.var_defs.get_mut(id.0).unwrap() - } - - pub fn get_variable(&self, id: VariableId) -> &VarDef { - self.var_defs.get(id.0).unwrap() - } - - pub fn get_type(&self, name: &str) -> Option { - // TODO: change this to handle import - self.sym_table.get(name).map(|v| v.clone()) - } -} - -impl Type { - pub fn subst(&self, map: &HashMap>) -> Type { - match self { - Type::TypeVariable(id) => map.get(id).map(|v| v.as_ref()).unwrap_or(self).clone(), - Type::ParametricType(id, params) => Type::ParametricType( - *id, - params - .iter() - .map(|v| v.as_ref().subst(map).into()) - .collect(), - ), - _ => self.clone(), - } - } - - pub fn inv_subst(&self, map: &[(Rc, Rc)]) -> Rc { - for (from, to) in map.iter() { - if self == from.as_ref() { - return to.clone(); - } - } - match self { - Type::ParametricType(id, params) => Type::ParametricType( - *id, - params - .iter() - .map(|v| v.as_ref().inv_subst(map).into()) - .collect(), - ), - _ => self.clone(), - } - .into() - } - - pub fn get_subst(&self, ctx: &GlobalContext) -> HashMap> { - match self { - Type::ParametricType(id, params) => { - let vars = &ctx.get_parametric(*id).params; - vars.iter() - .zip(params) - .map(|(v, p)| (*v, p.as_ref().clone().into())) - .collect() - } - // if this proves to be slow, we can use option type - _ => HashMap::new(), - } - } - - pub fn get_base<'b: 'a, 'a>(&'a self, ctx: &'b GlobalContext) -> Option<&'b TypeDef> { - match self { - Type::PrimitiveType(id) => Some(ctx.get_primitive(*id)), - Type::ClassType(id) | Type::VirtualClassType(id) => Some(&ctx.get_class(*id).base), - Type::ParametricType(id, _) => Some(&ctx.get_parametric(*id).base), - _ => None, - } - } + pub bound: Vec, }