diff --git a/nac3core/src/context/top_level_context.rs b/nac3core/src/context/top_level_context.rs index b001418a..6f6f5f0d 100644 --- a/nac3core/src/context/top_level_context.rs +++ b/nac3core/src/context/top_level_context.rs @@ -18,21 +18,28 @@ pub struct TopLevelContext<'a> { pub(super) fn_table: HashMap<&'a str, FnDef>, /// Type name to type mapping. pub(super) sym_table: HashMap<&'a str, Type>, + + primitives: Vec, + variables: Vec, } impl<'a> TopLevelContext<'a> { - pub fn new(primitives: Vec>) -> TopLevelContext { + pub fn new(primitive_defs: Vec>) -> TopLevelContext { let mut sym_table = HashMap::new(); - for (i, t) in primitives.iter().enumerate() { + let mut primitives = Vec::new(); + for (i, t) in primitive_defs.iter().enumerate() { + primitives.push(TypeEnum::PrimitiveType(PrimitiveId(i)).into()); sym_table.insert(t.name, TypeEnum::PrimitiveType(PrimitiveId(i)).into()); } return TopLevelContext { - primitive_defs: primitives, + primitive_defs, class_defs: Vec::new(), parametric_defs: Vec::new(), var_defs: Vec::new(), fn_table: HashMap::new(), sym_table, + primitives, + variables: Vec::new(), }; } @@ -69,6 +76,8 @@ impl<'a> TopLevelContext<'a> { pub fn add_variable_private(&mut self, def: VarDef<'a>) -> VariableId { self.var_defs.push(def); + self.variables + .push(TypeEnum::TypeVariable(VariableId(self.var_defs.len() - 1)).into()); VariableId(self.var_defs.len() - 1) } @@ -76,42 +85,50 @@ impl<'a> TopLevelContext<'a> { self.fn_table.insert(name, def); } - pub fn get_fn(&self, name: &str) -> Option<&FnDef> { + pub fn get_fn_def(&self, name: &str) -> Option<&FnDef> { self.fn_table.get(name) } - pub fn get_primitive_mut(&mut self, id: PrimitiveId) -> &mut TypeDef<'a> { + pub fn get_primitive_def_mut(&mut self, id: PrimitiveId) -> &mut TypeDef<'a> { self.primitive_defs.get_mut(id.0).unwrap() } - pub fn get_primitive(&self, id: PrimitiveId) -> &TypeDef { + pub fn get_primitive_def(&self, id: PrimitiveId) -> &TypeDef { self.primitive_defs.get(id.0).unwrap() } - pub fn get_class_mut(&mut self, id: ClassId) -> &mut ClassDef<'a> { + pub fn get_class_def_mut(&mut self, id: ClassId) -> &mut ClassDef<'a> { self.class_defs.get_mut(id.0).unwrap() } - pub fn get_class(&self, id: ClassId) -> &ClassDef { + pub fn get_class_def(&self, id: ClassId) -> &ClassDef { self.class_defs.get(id.0).unwrap() } - pub fn get_parametric_mut(&mut self, id: ParamId) -> &mut ParametricDef<'a> { + pub fn get_parametric_def_mut(&mut self, id: ParamId) -> &mut ParametricDef<'a> { self.parametric_defs.get_mut(id.0).unwrap() } - pub fn get_parametric(&self, id: ParamId) -> &ParametricDef { + pub fn get_parametric_def(&self, id: ParamId) -> &ParametricDef { self.parametric_defs.get(id.0).unwrap() } - pub fn get_variable_mut(&mut self, id: VariableId) -> &mut VarDef<'a> { + pub fn get_variable_def_mut(&mut self, id: VariableId) -> &mut VarDef<'a> { self.var_defs.get_mut(id.0).unwrap() } - pub fn get_variable(&self, id: VariableId) -> &VarDef { + pub fn get_variable_def(&self, id: VariableId) -> &VarDef { self.var_defs.get(id.0).unwrap() } + pub fn get_primitive(&self, id: PrimitiveId) -> Type { + self.primitives.get(id.0).unwrap().clone() + } + + pub fn get_variable(&self, id: VariableId) -> Type { + self.variables.get(id.0).unwrap().clone() + } + pub fn get_type(&self, name: &str) -> Option { // TODO: handle parametric types self.sym_table.get(name).map(|v| v.clone()) diff --git a/nac3core/src/lib.rs b/nac3core/src/lib.rs index 5a296655..ca0c5979 100644 --- a/nac3core/src/lib.rs +++ b/nac3core/src/lib.rs @@ -5,7 +5,7 @@ extern crate rustpython_parser; // pub mod expression; // pub mod inference; mod operators; -// pub mod primitives; +pub mod primitives; pub mod typedef; pub mod context; diff --git a/nac3core/src/primitives.rs b/nac3core/src/primitives.rs index 25b5cc3d..0d8e7eff 100644 --- a/nac3core/src/primitives.rs +++ b/nac3core/src/primitives.rs @@ -1,6 +1,6 @@ -use super::typedef::{Type::*, *}; +use super::typedef::{TypeEnum::*, *}; +use crate::context::*; use std::collections::HashMap; -use std::rc::Rc; pub const TUPLE_TYPE: ParamId = ParamId(0); pub const LIST_TYPE: ParamId = ParamId(1); @@ -10,7 +10,7 @@ pub const INT32_TYPE: PrimitiveId = PrimitiveId(1); pub const INT64_TYPE: PrimitiveId = PrimitiveId(2); pub const FLOAT_TYPE: PrimitiveId = PrimitiveId(3); -fn impl_math(def: &mut TypeDef, ty: &Rc) { +fn impl_math(def: &mut TypeDef, ty: &Type) { let result = Some(ty.clone()); let fun = FnDef { args: vec![ty.clone()], @@ -35,7 +35,7 @@ fn impl_math(def: &mut TypeDef, ty: &Rc) { def.methods.insert("__pow__", fun.clone()); } -fn impl_bits(def: &mut TypeDef, ty: &Rc) { +fn impl_bits(def: &mut TypeDef, ty: &Type) { let result = Some(ty.clone()); let fun = FnDef { args: vec![PrimitiveType(INT32_TYPE).into()], @@ -53,7 +53,7 @@ fn impl_bits(def: &mut TypeDef, ty: &Rc) { ); } -fn impl_eq(def: &mut TypeDef, ty: &Rc) { +fn impl_eq(def: &mut TypeDef, ty: &Type) { let fun = FnDef { args: vec![ty.clone()], result: Some(PrimitiveType(BOOL_TYPE).into()), @@ -63,7 +63,7 @@ fn impl_eq(def: &mut TypeDef, ty: &Rc) { def.methods.insert("__ne__", fun.clone()); } -fn impl_order(def: &mut TypeDef, ty: &Rc) { +fn impl_order(def: &mut TypeDef, ty: &Type) { let fun = FnDef { args: vec![ty.clone()], result: Some(PrimitiveType(BOOL_TYPE).into()), @@ -75,7 +75,7 @@ fn impl_order(def: &mut TypeDef, ty: &Rc) { def.methods.insert("__ge__", fun.clone()); } -pub fn basic_ctx() -> GlobalContext<'static> { +pub fn basic_ctx() -> TopLevelContext<'static> { let primitives = [ TypeDef { name: "bool", @@ -99,25 +99,25 @@ pub fn basic_ctx() -> GlobalContext<'static> { }, ] .to_vec(); - let mut ctx = GlobalContext::new(primitives); + let mut ctx = TopLevelContext::new(primitives); - let b_def = ctx.get_primitive_mut(BOOL_TYPE); - let b = PrimitiveType(BOOL_TYPE).into(); + let b = ctx.get_primitive(BOOL_TYPE); + let b_def = ctx.get_primitive_def_mut(BOOL_TYPE); impl_eq(b_def, &b); - let int32_def = ctx.get_primitive_mut(INT32_TYPE); - let int32 = PrimitiveType(INT32_TYPE).into(); + let int32 = ctx.get_primitive(INT32_TYPE); + let int32_def = ctx.get_primitive_def_mut(INT32_TYPE); impl_math(int32_def, &int32); impl_bits(int32_def, &int32); impl_order(int32_def, &int32); impl_eq(int32_def, &int32); - let int64_def = ctx.get_primitive_mut(INT64_TYPE); - let int64 = PrimitiveType(INT64_TYPE).into(); + let int64 = ctx.get_primitive(INT64_TYPE); + let int64_def = ctx.get_primitive_def_mut(INT64_TYPE); impl_math(int64_def, &int64); impl_bits(int64_def, &int64); impl_order(int64_def, &int64); impl_eq(int64_def, &int64); - let float_def = ctx.get_primitive_mut(FLOAT_TYPE); - let float = PrimitiveType(FLOAT_TYPE).into(); + let float = ctx.get_primitive(FLOAT_TYPE); + let float_def = ctx.get_primitive_def_mut(FLOAT_TYPE); impl_math(float_def, &float); impl_order(float_def, &float); impl_eq(float_def, &float);