diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index de9ff94..f78581f 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -49,7 +49,7 @@ pub struct CodeGenTask { pub signature: FunSignature, pub body: Vec>>, pub unifier_index: usize, - pub resolver: Arc, + pub resolver: Arc, } fn get_llvm_type<'ctx>( @@ -108,7 +108,7 @@ pub fn gen_func<'ctx>( module: Module<'ctx>, task: CodeGenTask, top_level_ctx: Arc, -) -> Module<'ctx> { +) -> (Builder<'ctx>, Module<'ctx>) { // unwrap_or(0) is for unit tests without using rayon let (mut unifier, primitives) = { let unifiers = top_level_ctx.unifiers.read(); @@ -199,5 +199,7 @@ pub fn gen_func<'ctx>( code_gen_context.gen_stmt(stmt); } - code_gen_context.module + let CodeGenContext { builder, module, .. } = code_gen_context; + + (builder, module) } diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index bcd5870..701d026 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -6,7 +6,7 @@ use crate::{ typecheck::{ magic_methods::set_primitives_magic_methods, type_inferencer::{CodeLocation, FunctionData, Inferencer, PrimitiveStore}, - typedef::{Call, FunSignature, FuncArg, Type, TypeEnum, Unifier}, + typedef::{CallId, FunSignature, FuncArg, Type, TypeEnum, Unifier}, }, }; use indoc::indoc; @@ -48,7 +48,7 @@ struct TestEnvironment { pub id_to_name: HashMap, pub identifier_mapping: HashMap, pub virtual_checks: Vec<(Type, Type)>, - pub calls: HashMap>, + pub calls: HashMap, pub top_level: TopLevelContext, } @@ -102,7 +102,7 @@ impl TestEnvironment { id_to_type: identifier_mapping.clone(), id_to_def: Default::default(), class_names: Default::default(), - }) as Arc; + }) as Arc; TestEnvironment { unifier, @@ -239,7 +239,7 @@ fn test_primitives() { } "} .trim(); - let ir = module.print_to_string().to_string(); + let ir = module.1.print_to_string().to_string(); println!("src:\n{}", source); println!("IR:\n{}", ir); assert_eq!(expected, ir.trim()); diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index e539ef1..5f716ee 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -161,19 +161,7 @@ pub fn parse_type_annotation( } } -impl dyn SymbolResolver { - pub fn parse_type_annotation( - &self, - top_level: &TopLevelContext, - unifier: &mut Unifier, - primitives: &PrimitiveStore, - expr: &Expr, - ) -> Result { - parse_type_annotation(self, top_level, unifier, primitives, expr) - } -} - -impl dyn SymbolResolver + Send { +impl dyn SymbolResolver + Send + Sync { pub fn parse_type_annotation( &self, top_level: &TopLevelContext, diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index cdd83a5..304431f 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -3,8 +3,8 @@ use std::convert::{From, TryInto}; use std::iter::once; use std::{cell::RefCell, sync::Arc}; -use super::magic_methods::*; use super::typedef::{Call, FunSignature, FuncArg, Type, TypeEnum, Unifier}; +use super::{magic_methods::*, typedef::CallId}; use crate::{symbol_resolver::SymbolResolver, top_level::TopLevelContext}; use itertools::izip; use rustpython_parser::ast::{ @@ -38,7 +38,7 @@ pub struct PrimitiveStore { } pub struct FunctionData { - pub resolver: Arc, + pub resolver: Arc, pub return_type: Option, pub bound_variables: Vec, } @@ -50,7 +50,7 @@ pub struct Inferencer<'a> { pub primitives: &'a PrimitiveStore, pub virtual_checks: &'a mut Vec<(Type, Type)>, pub variable_mapping: HashMap, - pub calls: &'a mut HashMap>, + pub calls: &'a mut HashMap, } struct NaiveFolder(); @@ -190,13 +190,13 @@ impl<'a> Inferencer<'a> { params: Vec, ret: Type, ) -> InferenceResult { - let call = Arc::new(Call { + let call = self.unifier.add_call(Call { posargs: params, kwargs: HashMap::new(), ret, fun: RefCell::new(None), }); - self.calls.insert(location.into(), call.clone()); + self.calls.insert(location.into(), call); let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into())); let fields = once((method, call)).collect(); let record = self.unifier.add_record(fields); @@ -398,7 +398,7 @@ impl<'a> Inferencer<'a> { .map(|v| fold::fold_keyword(self, v)) .collect::, _>>()?; let ret = self.unifier.get_fresh_var().0; - let call = Arc::new(Call { + let call = self.unifier.add_call(Call { posargs: args.iter().map(|v| v.custom.unwrap()).collect(), kwargs: keywords .iter() @@ -407,7 +407,7 @@ impl<'a> Inferencer<'a> { fun: RefCell::new(None), ret, }); - self.calls.insert(location.into(), call.clone()); + self.calls.insert(location.into(), call); let call = self.unifier.add_ty(TypeEnum::TCall(vec![call].into())); self.unifier.unify(func.custom.unwrap(), call)?; diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index 2101295..cebb0e8 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -40,7 +40,7 @@ struct TestEnvironment { pub id_to_name: HashMap, pub identifier_mapping: HashMap, pub virtual_checks: Vec<(Type, Type)>, - pub calls: HashMap>, + pub calls: HashMap, pub top_level: TopLevelContext, } @@ -94,7 +94,7 @@ impl TestEnvironment { id_to_type: identifier_mapping.clone(), id_to_def: Default::default(), class_names: Default::default(), - }) as Arc; + }) as Arc; TestEnvironment { top_level: TopLevelContext { @@ -273,7 +273,7 @@ impl TestEnvironment { .cloned() .collect(), class_names, - }) as Arc; + }) as Arc; TestEnvironment { unifier, diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index d2e75c8..6f0c34d 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -16,6 +16,9 @@ mod test; /// Handle for a type, implementated as a key in the unification table. pub type Type = UnificationKey; +#[derive(Clone, Copy, PartialEq, Eq)] +pub struct CallId(usize); + pub type Mapping = HashMap; type VarMap = Mapping; @@ -73,7 +76,7 @@ pub enum TypeEnum { TVirtual { ty: Type, }, - TCall(RefCell>>), + TCall(RefCell>), TFunc(FunSignature), } @@ -92,17 +95,18 @@ impl TypeEnum { } } -pub type SharedUnifier = Arc, u32)>>; +pub type SharedUnifier = Arc, u32, Vec)>>; pub struct Unifier { unification_table: UnificationTable>, + calls: Vec>, var_id: u32, } impl Unifier { /// Get an empty unifier pub fn new() -> Unifier { - Unifier { unification_table: UnificationTable::new(), var_id: 0 } + Unifier { unification_table: UnificationTable::new(), var_id: 0, calls: Vec::new() } } /// Determine if the two types are the same @@ -112,11 +116,19 @@ impl Unifier { pub fn from_shared_unifier(unifier: &SharedUnifier) -> Unifier { let lock = unifier.lock().unwrap(); - Unifier { unification_table: UnificationTable::from_send(&lock.0), var_id: lock.1 } + Unifier { + unification_table: UnificationTable::from_send(&lock.0), + var_id: lock.1, + calls: lock.2.iter().map(|v| Rc::new(v.clone())).collect_vec(), + } } pub fn get_shared_unifier(&self) -> SharedUnifier { - Arc::new(Mutex::new((self.unification_table.get_send(), self.var_id))) + Arc::new(Mutex::new(( + self.unification_table.get_send(), + self.var_id, + self.calls.iter().map(|v| v.as_ref().clone()).collect_vec(), + ))) } /// Register a type to the unifier. @@ -135,6 +147,12 @@ impl Unifier { }) } + pub fn add_call(&mut self, call: Call) -> CallId { + let id = CallId(self.calls.len()); + self.calls.push(Rc::new(call)); + id + } + pub fn get_representative(&mut self, ty: Type) -> Type { self.unification_table.get_representative(ty) } @@ -463,11 +481,11 @@ impl Unifier { .collect(); // we unify every calls to the function signature. for c in calls.borrow().iter() { - let Call { posargs, kwargs, ret, fun } = c.as_ref(); + let Call { posargs, kwargs, ret, fun } = &*self.calls[c.0].clone(); let instantiated = self.instantiate_fun(b, signature); - let signature; let r = self.get_ty(instantiated); let r = r.as_ref(); + let signature; if let TypeEnum::TFunc(s) = &*r { signature = s; } else { @@ -765,10 +783,14 @@ impl Unifier { } } TypeEnum::TCall(calls) => { + let call_store = self.calls.clone(); for t in calls .borrow() .iter() - .map(|call| chain!(call.posargs.iter(), call.kwargs.values(), once(&call.ret))) + .map(|call| { + let call = call_store[call.0].as_ref(); + chain!(call.posargs.iter(), call.kwargs.values(), once(&call.ret)) + }) .flatten() { self.occur_check(a, *t)?;