diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 036aee97..013785e5 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -109,105 +109,104 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { params: Vec<(Option, BasicValueEnum<'ctx>)>, ) -> Option> { let key = self.get_subst_key(obj.map(|a| a.0), fun.0, None); - let top_level_defs = self.top_level.definitions.read(); - let definition = top_level_defs.get(fun.1 .0).unwrap(); - let symbol = - if let TopLevelDef::Function { instance_to_symbol, .. } = &*definition.read() { + let definition = self.top_level.definitions.read().get(fun.1.0).cloned().unwrap(); + let mut task = None; + let symbol = { + // make sure this lock guard is dropped at the end of this scope... + let def = definition.read(); + if let TopLevelDef::Function { instance_to_symbol, .. } = &*def { instance_to_symbol.get(&key).cloned() } else { unreachable!() } - .unwrap_or_else(|| { - if let TopLevelDef::Function { - name, - instance_to_symbol, - instance_to_stmt, - var_id, - resolver, - .. - } = &mut *definition.write() - { - instance_to_symbol.get(&key).cloned().unwrap_or_else(|| { - let symbol = format!("{}_{}", name, instance_to_symbol.len()); - instance_to_symbol.insert(key, symbol.clone()); - let key = self.get_subst_key(obj.map(|a| a.0), fun.0, Some(var_id)); - let instance = instance_to_stmt.get(&key).unwrap(); - let unifiers = self.top_level.unifiers.read(); - let (unifier, primitives) = &unifiers[instance.unifier_id]; - let mut unifier = Unifier::from_shared_unifier(&unifier); + } + .unwrap_or_else(|| { + if let TopLevelDef::Function { + name, + instance_to_symbol, + instance_to_stmt, + var_id, + resolver, + .. + } = &mut *definition.write() + { + instance_to_symbol.get(&key).cloned().unwrap_or_else(|| { + let symbol = format!("{}_{}", name, instance_to_symbol.len()); + instance_to_symbol.insert(key, symbol.clone()); + let key = self.get_subst_key(obj.map(|a| a.0), fun.0, Some(var_id)); + let instance = instance_to_stmt.get(&key).unwrap(); + let unifiers = self.top_level.unifiers.read(); + let (unifier, primitives) = &unifiers[instance.unifier_id]; + let mut unifier = Unifier::from_shared_unifier(&unifier); - let mut type_cache = [ - (self.primitives.int32, primitives.int32), - (self.primitives.int64, primitives.int64), - (self.primitives.float, primitives.float), - (self.primitives.bool, primitives.bool), - (self.primitives.none, primitives.none), - ] + let mut type_cache = [ + (self.primitives.int32, primitives.int32), + (self.primitives.int64, primitives.int64), + (self.primitives.float, primitives.float), + (self.primitives.bool, primitives.bool), + (self.primitives.none, primitives.none), + ] + .iter() + .map(|(a, b)| { + (self.unifier.get_representative(*a), unifier.get_representative(*b)) + }) + .collect(); + + let subst = fun + .0 + .vars .iter() - .map(|(a, b)| { - (self.unifier.get_representative(*a), unifier.get_representative(*b)) + .map(|(id, ty)| { + ( + *instance.subst.get(id).unwrap(), + unifier.copy_from(&mut self.unifier, *ty, &mut type_cache), + ) }) .collect(); - let subst = fun + let signature = FunSignature { + args: fun + .0 + .args + .iter() + .map(|arg| FuncArg { + name: arg.name.clone(), + ty: unifier.copy_from(&mut self.unifier, arg.ty, &mut type_cache), + default_value: arg.default_value.clone(), + }) + .collect(), + ret: unifier.copy_from(&mut self.unifier, fun.0.ret, &mut type_cache), + vars: fun .0 .vars .iter() .map(|(id, ty)| { - ( - *instance.subst.get(id).unwrap(), - unifier.copy_from(&mut self.unifier, *ty, &mut type_cache), - ) + (*id, unifier.copy_from(&mut self.unifier, *ty, &mut type_cache)) }) - .collect(); + .collect(), + }; - let signature = FunSignature { - args: fun - .0 - .args - .iter() - .map(|arg| FuncArg { - name: arg.name.clone(), - ty: unifier.copy_from( - &mut self.unifier, - arg.ty, - &mut type_cache, - ), - default_value: arg.default_value.clone(), - }) - .collect(), - ret: unifier.copy_from(&mut self.unifier, fun.0.ret, &mut type_cache), - vars: fun - .0 - .vars - .iter() - .map(|(id, ty)| { - ( - *id, - unifier.copy_from(&mut self.unifier, *ty, &mut type_cache), - ) - }) - .collect(), - }; + let unifier = (unifier.get_shared_unifier(), *primitives); - let unifier = (unifier.get_shared_unifier(), *primitives); + task = Some(CodeGenTask { + symbol_name: symbol.clone(), + body: instance.body.clone(), + resolver: resolver.as_ref().unwrap().clone(), + calls: instance.calls.clone(), + subst, + signature, + unifier, + }); + symbol + }) + } else { + unreachable!() + } + }); - let task = CodeGenTask { - symbol_name: symbol.clone(), - body: instance.body.clone(), - resolver: resolver.as_ref().unwrap().clone(), - calls: instance.calls.clone(), - subst, - signature, - unifier, - }; - self.registry.add_task(task); - symbol - }) - } else { - unreachable!() - } - }); + if let Some(task) = task { + self.registry.add_task(task); + } let fun_val = self.module.get_function(&symbol).unwrap_or_else(|| { let params = fun.0.args.iter().map(|arg| self.get_llvm_type(arg.ty)).collect_vec(); diff --git a/nac3core/src/codegen/test.rs b/nac3core/src/codegen/test.rs index 978bf414..cde701b2 100644 --- a/nac3core/src/codegen/test.rs +++ b/nac3core/src/codegen/test.rs @@ -1,26 +1,26 @@ -use crate::{ - codegen::{CodeGenTask, WithCall, WorkerRegistry}, - location::Location, - symbol_resolver::{SymbolResolver, SymbolValue}, - toplevel::{DefinitionId, TopLevelComposer, TopLevelContext}, - typecheck::{ +use crate::{codegen::{CodeGenTask, WithCall, WorkerRegistry}, location::Location, symbol_resolver::{SymbolResolver, SymbolValue}, toplevel::{DefinitionId, FunInstance, TopLevelComposer, TopLevelContext, TopLevelDef}, typecheck::{ type_inferencer::{FunctionData, Inferencer, PrimitiveStore}, - typedef::{FunSignature, FuncArg, Type, Unifier}, - }, -}; + typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier}, + }}; use indoc::indoc; use parking_lot::RwLock; use rustpython_parser::{ast::fold::Fold, parser::parse_program}; use std::collections::HashMap; use std::sync::Arc; +use std::cell::RefCell; -#[derive(Clone)] struct Resolver { id_to_type: HashMap, - id_to_def: HashMap, + id_to_def: RwLock>, class_names: HashMap, } +impl Resolver { + pub fn add_id_def(&self, id: String, def: DefinitionId) { + self.id_to_def.write().insert(id, def); + } +} + impl SymbolResolver for Resolver { fn get_symbol_type(&self, _: &mut Unifier, _: &PrimitiveStore, str: &str) -> Option { self.id_to_type.get(str).cloned() @@ -35,7 +35,7 @@ impl SymbolResolver for Resolver { } fn get_identifier_def(&self, id: &str) -> Option { - self.id_to_def.get(id).cloned() + self.id_to_def.read().get(id).cloned() } } @@ -56,7 +56,7 @@ fn test_primitives() { let resolver = Arc::new(Box::new(Resolver { id_to_type: HashMap::new(), - id_to_def: HashMap::new(), + id_to_def: RwLock::new(HashMap::new()), class_names: Default::default(), }) as Box); @@ -175,5 +175,160 @@ fn test_primitives() { let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f); registry.add_task(task); registry.wait_tasks_complete(handles); - println!("object file is in mandelbrot.o") +} + +#[test] +fn test_simple_call() { + let source_1 = indoc! { " + a = foo(a) + return a * 2 + "}; + let statements_1 = parse_program(source_1).unwrap(); + + let source_2 = indoc! { " + return a + 1 + "}; + let statements_2 = parse_program(source_2).unwrap(); + + let (_, composer) = TopLevelComposer::new(); + let mut unifier = composer.unifier.clone(); + let primitives = composer.primitives_ty; + let top_level = Arc::new(composer.make_top_level_context()); + unifier.top_level = Some(top_level.clone()); + + let signature = FunSignature { + args: vec![FuncArg { name: "a".to_string(), ty: primitives.int32, default_value: None }], + ret: primitives.int32, + vars: HashMap::new(), + }; + let fun_ty = unifier.add_ty(TypeEnum::TFunc(RefCell::new(signature.clone()))); + + let foo_id = top_level.definitions.read().len(); + top_level.definitions.write().push(Arc::new(RwLock::new(TopLevelDef::Function { + name: "foo".to_string(), + signature: fun_ty, + var_id: vec![], + instance_to_stmt: HashMap::new(), + instance_to_symbol: HashMap::new(), + resolver: None + }))); + + let resolver = Box::new(Resolver { + id_to_type: HashMap::new(), + id_to_def: RwLock::new(HashMap::new()), + class_names: Default::default(), + }); + resolver.add_id_def("foo".to_string(), DefinitionId(foo_id)); + let resolver = Arc::new(resolver as Box); + + if let TopLevelDef::Function {resolver: r, ..} = &mut *top_level.definitions.read()[foo_id].write() { + *r = Some(resolver.clone()); + } else { + unreachable!() + } + + let threads = ["test"]; + let mut function_data = FunctionData { + resolver: resolver.clone(), + bound_variables: Vec::new(), + return_type: Some(primitives.int32), + }; + let mut virtual_checks = Vec::new(); + let mut calls = HashMap::new(); + let mut inferencer = Inferencer { + top_level: &top_level, + function_data: &mut function_data, + unifier: &mut unifier, + variable_mapping: Default::default(), + primitives: &primitives, + virtual_checks: &mut virtual_checks, + calls: &mut calls, + }; + inferencer.variable_mapping.insert("a".into(), inferencer.primitives.int32); + inferencer.variable_mapping.insert("foo".into(), fun_ty); + + let statements_1 = statements_1 + .into_iter() + .map(|v| inferencer.fold_stmt(v)) + .collect::, _>>() + .unwrap(); + + let calls1 = inferencer.calls.clone(); + inferencer.calls.clear(); + + let statements_2 = statements_2 + .into_iter() + .map(|v| inferencer.fold_stmt(v)) + .collect::, _>>() + .unwrap(); + + if let TopLevelDef::Function {instance_to_stmt, ..} = &mut *top_level.definitions.read()[foo_id].write() { + instance_to_stmt.insert("".to_string(), FunInstance { + body: statements_2, + calls: inferencer.calls.clone(), + subst: Default::default(), + unifier_id: 0 + }); + } else { + unreachable!() + } + + let mut identifiers = vec!["a".to_string(), "foo".into()]; + inferencer.check_block(&statements_1, &mut identifiers).unwrap(); + let top_level = Arc::new(TopLevelContext { + definitions: Arc::new(RwLock::new(std::mem::take(&mut *top_level.definitions.write()))), + unifiers: Arc::new(RwLock::new(vec![(unifier.get_shared_unifier(), primitives)])), + }); + + let unifier = (unifier.get_shared_unifier(), primitives); + + let task = CodeGenTask { + subst: Default::default(), + symbol_name: "testing".to_string(), + body: statements_1, + resolver, + unifier, + calls: calls1, + signature, + }; + let f = Arc::new(WithCall::new(Box::new(|module| { + let expected = indoc! {" + ; ModuleID = 'test' + source_filename = \"test\" + + define i32 @testing(i32 %0) { + init: + %a = alloca i32, align 4 + store i32 %0, i32* %a, align 4 + br label %body + + body: ; preds = %init + %load = load i32, i32* %a, align 4 + %call = call i32 @foo_0(i32 %load) + store i32 %call, i32* %a, align 4 + %load1 = load i32, i32* %a, align 4 + %mul = mul i32 %load1, 2 + ret i32 %mul + } + + declare i32 @foo_0(i32) + + define i32 @foo_0.1(i32 %0) { + init: + %a = alloca i32, align 4 + store i32 %0, i32* %a, align 4 + br label %body + + body: ; preds = %init + %load = load i32, i32* %a, align 4 + %add = add i32 %load, 1 + ret i32 %add + } + "} + .trim(); + assert_eq!(expected, module.print_to_string().to_str().unwrap().trim()); + }))); + let (registry, handles) = WorkerRegistry::create_workers(&threads, top_level, f); + registry.add_task(task); + registry.wait_tasks_complete(handles); }