diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index f852fab9..f002eff9 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -1,6 +1,6 @@ use std::cell::RefCell; use std::collections::HashMap; -use std::convert::TryInto; +use std::convert::{TryInto, From}; use std::iter::once; use std::rc::Rc; @@ -17,6 +17,21 @@ use rustpython_parser::ast::{ #[cfg(test)] mod test; +#[derive(PartialEq, Eq, Hash, Copy, Clone, Debug)] +pub struct CodeLocation { + row: usize, + col: usize, +} + +impl From for CodeLocation { + fn from(loc: Location) -> CodeLocation { + CodeLocation { + row: loc.row(), + col: loc.column() + } + } +} + pub struct PrimitiveStore { pub int32: Type, pub int64: Type, @@ -37,6 +52,7 @@ pub struct Inferencer<'a> { pub primitives: &'a PrimitiveStore, pub virtual_checks: &'a mut Vec<(Type, Type)>, pub variable_mapping: HashMap, + pub calls: &'a mut HashMap>, } struct NaiveFolder(); @@ -215,6 +231,7 @@ impl<'a> Inferencer<'a> { unifier: self.unifier, primitives: self.primitives, virtual_checks: self.virtual_checks, + calls: self.calls, variable_mapping, }; let fun = FunSignature { @@ -257,6 +274,7 @@ impl<'a> Inferencer<'a> { virtual_checks: self.virtual_checks, variable_mapping, primitives: self.primitives, + calls: self.calls, }; let elt = new_context.fold_expr(elt)?; let generator = generators.pop().unwrap(); @@ -379,6 +397,7 @@ impl<'a> Inferencer<'a> { fun: RefCell::new(None), ret, }); + self.calls.insert(location.into(), call.clone()); let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into())); self.unifier.unify(func.custom.unwrap(), call)?; diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index 16a5ffff..c1bceed5 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -42,6 +42,7 @@ struct TestEnvironment { pub id_to_name: HashMap, pub identifier_mapping: HashMap, pub virtual_checks: Vec<(Type, Type)>, + pub calls: HashMap>, } impl TestEnvironment { @@ -151,12 +152,13 @@ impl TestEnvironment { function_data: FunctionData { resolver, bound_variables: Vec::new(), - return_type: None + return_type: None, }, primitives, id_to_name, identifier_mapping, virtual_checks: Vec::new(), + calls: HashMap::new(), } } @@ -167,6 +169,7 @@ impl TestEnvironment { variable_mapping: Default::default(), primitives: &mut self.primitives, virtual_checks: &mut self.virtual_checks, + calls: &mut self.calls, } } }