From 7cb9be0f811f43de34b6c02f8fad0e26654eadaf Mon Sep 17 00:00:00 2001
From: occheung <dc@m-labs.hk>
Date: Tue, 31 May 2022 15:12:18 +0800
Subject: [PATCH] nac3artiq: refactor compile methods

Avoids writing relocatable object to a file when linking with nac3ld.
---
 nac3artiq/src/lib.rs | 567 ++++++++++++++++++++++++-------------------
 1 file changed, 317 insertions(+), 250 deletions(-)

diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs
index 4f0b49f..855b133 100644
--- a/nac3artiq/src/lib.rs
+++ b/nac3artiq/src/lib.rs
@@ -7,7 +7,7 @@ use std::sync::Arc;
 
 use inkwell::{
     memory_buffer::MemoryBuffer,
-    module::Linkage,
+    module::{Linkage, Module},
     passes::{PassManager, PassManagerBuilder},
     targets::*,
     OptimizationLevel,
@@ -266,213 +266,16 @@ impl Nac3 {
         }
         None
     }
-}
 
-fn add_exceptions(
-    composer: &mut TopLevelComposer,
-    builtin_def: &mut HashMap<StrRef, DefinitionId>,
-    builtin_ty: &mut HashMap<StrRef, Type>,
-    error_names: &[&str]
-) -> Vec<Type> {
-    let mut types = Vec::new();
-    // note: this is only for builtin exceptions, i.e. the exception name is "0:{exn}"
-    for name in error_names {
-        let def_id = composer.definition_ast_list.len();
-        let (exception_fn, exception_class, exception_cons, exception_type) = get_exn_constructor(
-            name,
-            // class id
-            def_id,
-            // constructor id
-            def_id + 1,
-            &mut composer.unifier,
-            &composer.primitives_ty
-        );
-        composer.definition_ast_list.push((Arc::new(RwLock::new(exception_class)), None));
-        composer.definition_ast_list.push((Arc::new(RwLock::new(exception_fn)), None));
-        builtin_ty.insert((*name).into(), exception_cons);
-        builtin_def.insert((*name).into(), DefinitionId(def_id));
-        types.push(exception_type);
-    }
-    types
-}
-
-#[pymethods]
-impl Nac3 {
-    #[new]
-    fn new(isa: &str, py: Python) -> PyResult<Self> {
-        let isa = match isa {
-            "host" => Isa::Host,
-            "rv32g" => Isa::RiscV32G,
-            "rv32ima" => Isa::RiscV32IMA,
-            "cortexa9" => Isa::CortexA9,
-            _ => return Err(exceptions::PyValueError::new_err("invalid ISA")),
-        };
-        let time_fns: &(dyn TimeFns + Sync) = match isa {
-            Isa::Host => &timeline::EXTERN_TIME_FNS,
-            Isa::RiscV32G => &timeline::NOW_PINNING_TIME_FNS_64,
-            Isa::RiscV32IMA => &timeline::NOW_PINNING_TIME_FNS,
-            Isa::CortexA9 => &timeline::EXTERN_TIME_FNS,
-        };
-        let primitive: PrimitiveStore = TopLevelComposer::make_primitives().0;
-        let builtins = vec![
-            (
-                "now_mu".into(),
-                FunSignature { args: vec![], ret: primitive.int64, vars: HashMap::new() },
-                Arc::new(GenCall::new(Box::new(move |ctx, _, _, _, _| {
-                    Ok(Some(time_fns.emit_now_mu(ctx)))
-                }))),
-            ),
-            (
-                "at_mu".into(),
-                FunSignature {
-                    args: vec![FuncArg {
-                        name: "t".into(),
-                        ty: primitive.int64,
-                        default_value: None,
-                    }],
-                    ret: primitive.none,
-                    vars: HashMap::new(),
-                },
-                Arc::new(GenCall::new(Box::new(move |ctx, _, fun, args, generator| {
-                    let arg_ty = fun.0.args[0].ty;
-                    let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty).unwrap();
-                    time_fns.emit_at_mu(ctx, arg);
-                    Ok(None)
-                }))),
-            ),
-            (
-                "delay_mu".into(),
-                FunSignature {
-                    args: vec![FuncArg {
-                        name: "dt".into(),
-                        ty: primitive.int64,
-                        default_value: None,
-                    }],
-                    ret: primitive.none,
-                    vars: HashMap::new(),
-                },
-                Arc::new(GenCall::new(Box::new(move |ctx, _, fun, args, generator| {
-                    let arg_ty = fun.0.args[0].ty;
-                    let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty).unwrap();
-                    time_fns.emit_delay_mu(ctx, arg);
-                    Ok(None)
-                }))),
-            ),
-        ];
-
-        let builtins_mod = PyModule::import(py, "builtins").unwrap();
-        let id_fn = builtins_mod.getattr("id").unwrap();
-        let numpy_mod = PyModule::import(py, "numpy").unwrap();
-        let typing_mod = PyModule::import(py, "typing").unwrap();
-        let types_mod = PyModule::import(py, "types").unwrap();
-
-        let get_id = |x| id_fn.call1((x,)).unwrap().extract().unwrap();
-        let get_attr_id = |obj: &PyModule, attr| id_fn.call1((obj.getattr(attr).unwrap(),))
-            .unwrap().extract().unwrap();
-        let primitive_ids = PrimitivePythonId {
-            virtual_id: get_id(
-                            builtins_mod
-                            .getattr("globals")
-                            .unwrap()
-                            .call0()
-                            .unwrap()
-                            .get_item("virtual")
-                            .unwrap(
-                        )),
-            generic_alias: (
-                get_attr_id(typing_mod, "_GenericAlias"),
-                get_attr_id(types_mod, "GenericAlias"),
-            ),
-            none: id_fn
-                .call1((builtins_mod
-                    .getattr("globals")
-                    .unwrap()
-                    .call0()
-                    .unwrap()
-                    .get_item("none")
-                    .unwrap(),))
-                .unwrap()
-                .extract()
-                .unwrap(),
-            typevar: get_attr_id(typing_mod, "TypeVar"),
-            int: get_attr_id(builtins_mod, "int"),
-            int32: get_attr_id(numpy_mod, "int32"),
-            int64: get_attr_id(numpy_mod, "int64"),
-            uint32: get_attr_id(numpy_mod, "uint32"),
-            uint64: get_attr_id(numpy_mod, "uint64"),
-            bool: get_attr_id(builtins_mod, "bool"),
-            float: get_attr_id(builtins_mod, "float"),
-            float64: get_attr_id(numpy_mod, "float64"),
-            list: get_attr_id(builtins_mod, "list"),
-            tuple: get_attr_id(builtins_mod, "tuple"),
-            exception: get_attr_id(builtins_mod, "Exception"),
-            option: id_fn
-                .call1((builtins_mod
-                    .getattr("globals")
-                    .unwrap()
-                    .call0()
-                    .unwrap()
-                    .get_item("Option")
-                    .unwrap(),))
-                .unwrap()
-                .extract()
-                .unwrap(),
-        };
-
-        let working_directory = tempfile::Builder::new().prefix("nac3-").tempdir().unwrap();
-        fs::write(working_directory.path().join("kernel.ld"), include_bytes!("kernel.ld")).unwrap();
-
-        Ok(Nac3 {
-            isa,
-            time_fns,
-            primitive,
-            builtins,
-            primitive_ids,
-            top_levels: Default::default(),
-            pyid_to_def: Default::default(),
-            working_directory,
-            string_store: Default::default(),
-            exception_ids: Default::default(),
-            deferred_eval_store: DeferredEvaluationStore::new(),
-        })
-    }
-
-    fn analyze(&mut self, functions: &PySet, classes: &PySet) -> PyResult<()> {
-        let (modules, class_ids) =
-            Python::with_gil(|py| -> PyResult<(HashMap<u64, PyObject>, HashSet<u64>)> {
-                let mut modules: HashMap<u64, PyObject> = HashMap::new();
-                let mut class_ids: HashSet<u64> = HashSet::new();
-
-                let id_fn = PyModule::import(py, "builtins")?.getattr("id")?;
-                let getmodule_fn = PyModule::import(py, "inspect")?.getattr("getmodule")?;
-
-                for function in functions.iter() {
-                    let module = getmodule_fn.call1((function,))?.extract()?;
-                    modules.insert(id_fn.call1((&module,))?.extract()?, module);
-                }
-                for class in classes.iter() {
-                    let module = getmodule_fn.call1((class,))?.extract()?;
-                    modules.insert(id_fn.call1((&module,))?.extract()?, module);
-                    class_ids.insert(id_fn.call1((class,))?.extract()?);
-                }
-                Ok((modules, class_ids))
-            })?;
-
-        for module in modules.into_values() {
-            self.register_module(module, &class_ids)?;
-        }
-        Ok(())
-    }
-
-    fn compile_method_to_file(
-        &mut self,
+    fn compile_method<T>(
+        &self,
         obj: &PyAny,
         method_name: &str,
         args: Vec<&PyAny>,
-        filename: &str,
         embedding_map: &PyAny,
         py: Python,
-    ) -> PyResult<()> {
+        link_fn: &dyn Fn(&Module) -> PyResult<T>,
+    ) -> PyResult<T> {
         let (mut composer, mut builtins_def, mut builtins_ty) = TopLevelComposer::new(
             self.builtins.clone(),
             ComposerConfig { kernel_ann: Some("Kernel"), kernel_invariant_ann: "KernelInvariant" },
@@ -759,8 +562,6 @@ impl Nac3 {
             calls: Arc::new(Default::default()),
             id: 0,
         };
-        let isa = self.isa;
-        let working_directory = self.working_directory.path().to_owned();
 
         let membuffers: Arc<Mutex<Vec<Vec<u8>>>> = Default::default();
 
@@ -846,7 +647,13 @@ impl Nac3 {
         builder.populate_module_pass_manager(&passes);
         passes.run_on(&main);
 
-        let (triple, features) = match isa {
+        link_fn(&main)
+    }
+
+    fn get_llvm_target_machine(
+        &self,
+    ) -> TargetMachine {
+        let (triple, features) = match self.isa {
             Isa::Host => (
                 TargetMachine::get_default_triple(),
                 TargetMachine::get_host_cpu_features().to_string(),
@@ -862,7 +669,7 @@ impl Nac3 {
         };
         let target =
             Target::from_triple(&triple).expect("couldn't create target from target triple");
-        let target_machine = target
+        target
             .create_target_machine(
                 &triple,
                 "",
@@ -871,52 +678,282 @@ impl Nac3 {
                 RelocMode::PIC,
                 CodeModel::Default,
             )
-            .expect("couldn't create target machine");
-        
-        if isa == Isa::Host {
-            target_machine
-                .write_to_file(&main, FileType::Object, &working_directory.join("module.o"))
-                .expect("couldn't write module to file");
-            let linker_args = vec![
-                "-shared".to_string(),
-                "--eh-frame-hdr".to_string(),
-                "-x".to_string(),
-                "-o".to_string(),
-                filename.to_string(),
-                working_directory.join("module.o").to_string_lossy().to_string(),
-            ];
+            .expect("couldn't create target machine")
+    }
+}
 
-            #[cfg(not(windows))]
-            let lld_command = "ld.lld";
-            #[cfg(windows)]
-            let lld_command = "ld.lld.exe";
-            if let Ok(linker_status) = Command::new(lld_command).args(linker_args).status() {
-                if !linker_status.success() {
-                    return Err(CompileError::new_err("failed to start linker"));
-                }
-            } else {
-                return Err(CompileError::new_err(
-                    "linker returned non-zero status code",
-                ));
-            }
-        } else {
-            let object_mem = target_machine
-                .write_to_memory_buffer(&main, FileType::Object)
-                .expect("couldn't write module to object file buffer");
-            if let Ok(dyn_lib) = Linker::ld(object_mem.as_slice()) {
-                if let Ok(mut file) = fs::File::create(filename) {
-                    file.write_all(&dyn_lib).expect("couldn't write linked library to file");
-                } else {
-                    return Err(CompileError::new_err("failed to create file"));
-                }
-            } else {
-                return Err(CompileError::new_err("linker failed to process object file"));
-            }
+fn link_with_lld(
+    elf_filename: String,
+    obj_filename: String,
+) -> PyResult<()>{
+    let linker_args = vec![
+        "-shared".to_string(),
+        "--eh-frame-hdr".to_string(),
+        "-x".to_string(),
+        "-o".to_string(),
+        elf_filename,
+        obj_filename,
+    ];
+
+    #[cfg(not(windows))]
+    let lld_command = "ld.lld";
+    #[cfg(windows)]
+    let lld_command = "ld.lld.exe";
+    if let Ok(linker_status) = Command::new(lld_command).args(linker_args).status() {
+        if !linker_status.success() {
+            return Err(CompileError::new_err("failed to start linker"));
         }
+    } else {
+        return Err(CompileError::new_err(
+            "linker returned non-zero status code",
+        ));
+    }
 
+    Ok(())
+}
+
+fn add_exceptions(
+    composer: &mut TopLevelComposer,
+    builtin_def: &mut HashMap<StrRef, DefinitionId>,
+    builtin_ty: &mut HashMap<StrRef, Type>,
+    error_names: &[&str]
+) -> Vec<Type> {
+    let mut types = Vec::new();
+    // note: this is only for builtin exceptions, i.e. the exception name is "0:{exn}"
+    for name in error_names {
+        let def_id = composer.definition_ast_list.len();
+        let (exception_fn, exception_class, exception_cons, exception_type) = get_exn_constructor(
+            name,
+            // class id
+            def_id,
+            // constructor id
+            def_id + 1,
+            &mut composer.unifier,
+            &composer.primitives_ty
+        );
+        composer.definition_ast_list.push((Arc::new(RwLock::new(exception_class)), None));
+        composer.definition_ast_list.push((Arc::new(RwLock::new(exception_fn)), None));
+        builtin_ty.insert((*name).into(), exception_cons);
+        builtin_def.insert((*name).into(), DefinitionId(def_id));
+        types.push(exception_type);
+    }
+    types
+}
+
+#[pymethods]
+impl Nac3 {
+    #[new]
+    fn new(isa: &str, py: Python) -> PyResult<Self> {
+        let isa = match isa {
+            "host" => Isa::Host,
+            "rv32g" => Isa::RiscV32G,
+            "rv32ima" => Isa::RiscV32IMA,
+            "cortexa9" => Isa::CortexA9,
+            _ => return Err(exceptions::PyValueError::new_err("invalid ISA")),
+        };
+        let time_fns: &(dyn TimeFns + Sync) = match isa {
+            Isa::Host => &timeline::EXTERN_TIME_FNS,
+            Isa::RiscV32G => &timeline::NOW_PINNING_TIME_FNS_64,
+            Isa::RiscV32IMA => &timeline::NOW_PINNING_TIME_FNS,
+            Isa::CortexA9 => &timeline::EXTERN_TIME_FNS,
+        };
+        let primitive: PrimitiveStore = TopLevelComposer::make_primitives().0;
+        let builtins = vec![
+            (
+                "now_mu".into(),
+                FunSignature { args: vec![], ret: primitive.int64, vars: HashMap::new() },
+                Arc::new(GenCall::new(Box::new(move |ctx, _, _, _, _| {
+                    Ok(Some(time_fns.emit_now_mu(ctx)))
+                }))),
+            ),
+            (
+                "at_mu".into(),
+                FunSignature {
+                    args: vec![FuncArg {
+                        name: "t".into(),
+                        ty: primitive.int64,
+                        default_value: None,
+                    }],
+                    ret: primitive.none,
+                    vars: HashMap::new(),
+                },
+                Arc::new(GenCall::new(Box::new(move |ctx, _, fun, args, generator| {
+                    let arg_ty = fun.0.args[0].ty;
+                    let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty).unwrap();
+                    time_fns.emit_at_mu(ctx, arg);
+                    Ok(None)
+                }))),
+            ),
+            (
+                "delay_mu".into(),
+                FunSignature {
+                    args: vec![FuncArg {
+                        name: "dt".into(),
+                        ty: primitive.int64,
+                        default_value: None,
+                    }],
+                    ret: primitive.none,
+                    vars: HashMap::new(),
+                },
+                Arc::new(GenCall::new(Box::new(move |ctx, _, fun, args, generator| {
+                    let arg_ty = fun.0.args[0].ty;
+                    let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty).unwrap();
+                    time_fns.emit_delay_mu(ctx, arg);
+                    Ok(None)
+                }))),
+            ),
+        ];
+
+        let builtins_mod = PyModule::import(py, "builtins").unwrap();
+        let id_fn = builtins_mod.getattr("id").unwrap();
+        let numpy_mod = PyModule::import(py, "numpy").unwrap();
+        let typing_mod = PyModule::import(py, "typing").unwrap();
+        let types_mod = PyModule::import(py, "types").unwrap();
+
+        let get_id = |x| id_fn.call1((x,)).unwrap().extract().unwrap();
+        let get_attr_id = |obj: &PyModule, attr| id_fn.call1((obj.getattr(attr).unwrap(),))
+            .unwrap().extract().unwrap();
+        let primitive_ids = PrimitivePythonId {
+            virtual_id: get_id(
+                            builtins_mod
+                            .getattr("globals")
+                            .unwrap()
+                            .call0()
+                            .unwrap()
+                            .get_item("virtual")
+                            .unwrap(
+                        )),
+            generic_alias: (
+                get_attr_id(typing_mod, "_GenericAlias"),
+                get_attr_id(types_mod, "GenericAlias"),
+            ),
+            none: id_fn
+                .call1((builtins_mod
+                    .getattr("globals")
+                    .unwrap()
+                    .call0()
+                    .unwrap()
+                    .get_item("none")
+                    .unwrap(),))
+                .unwrap()
+                .extract()
+                .unwrap(),
+            typevar: get_attr_id(typing_mod, "TypeVar"),
+            int: get_attr_id(builtins_mod, "int"),
+            int32: get_attr_id(numpy_mod, "int32"),
+            int64: get_attr_id(numpy_mod, "int64"),
+            uint32: get_attr_id(numpy_mod, "uint32"),
+            uint64: get_attr_id(numpy_mod, "uint64"),
+            bool: get_attr_id(builtins_mod, "bool"),
+            float: get_attr_id(builtins_mod, "float"),
+            float64: get_attr_id(numpy_mod, "float64"),
+            list: get_attr_id(builtins_mod, "list"),
+            tuple: get_attr_id(builtins_mod, "tuple"),
+            exception: get_attr_id(builtins_mod, "Exception"),
+            option: id_fn
+                .call1((builtins_mod
+                    .getattr("globals")
+                    .unwrap()
+                    .call0()
+                    .unwrap()
+                    .get_item("Option")
+                    .unwrap(),))
+                .unwrap()
+                .extract()
+                .unwrap(),
+        };
+
+        let working_directory = tempfile::Builder::new().prefix("nac3-").tempdir().unwrap();
+        fs::write(working_directory.path().join("kernel.ld"), include_bytes!("kernel.ld")).unwrap();
+
+        Ok(Nac3 {
+            isa,
+            time_fns,
+            primitive,
+            builtins,
+            primitive_ids,
+            top_levels: Default::default(),
+            pyid_to_def: Default::default(),
+            working_directory,
+            string_store: Default::default(),
+            exception_ids: Default::default(),
+            deferred_eval_store: DeferredEvaluationStore::new(),
+        })
+    }
+
+    fn analyze(&mut self, functions: &PySet, classes: &PySet) -> PyResult<()> {
+        let (modules, class_ids) =
+            Python::with_gil(|py| -> PyResult<(HashMap<u64, PyObject>, HashSet<u64>)> {
+                let mut modules: HashMap<u64, PyObject> = HashMap::new();
+                let mut class_ids: HashSet<u64> = HashSet::new();
+
+                let id_fn = PyModule::import(py, "builtins")?.getattr("id")?;
+                let getmodule_fn = PyModule::import(py, "inspect")?.getattr("getmodule")?;
+
+                for function in functions.iter() {
+                    let module = getmodule_fn.call1((function,))?.extract()?;
+                    modules.insert(id_fn.call1((&module,))?.extract()?, module);
+                }
+                for class in classes.iter() {
+                    let module = getmodule_fn.call1((class,))?.extract()?;
+                    modules.insert(id_fn.call1((&module,))?.extract()?, module);
+                    class_ids.insert(id_fn.call1((class,))?.extract()?);
+                }
+                Ok((modules, class_ids))
+            })?;
+
+        for module in modules.into_values() {
+            self.register_module(module, &class_ids)?;
+        }
         Ok(())
     }
 
+    fn compile_method_to_file(
+        &mut self,
+        obj: &PyAny,
+        method_name: &str,
+        args: Vec<&PyAny>,
+        filename: &str,
+        embedding_map: &PyAny,
+        py: Python,
+    ) -> PyResult<()> {
+        let target_machine = self.get_llvm_target_machine();
+        
+        if self.isa == Isa::Host {
+            let link_fn = |module: &Module| {
+                let working_directory = self.working_directory.path().to_owned();
+                target_machine
+                    .write_to_file(module, FileType::Object, &working_directory.join("module.o"))
+                    .expect("couldn't write module to file");
+                link_with_lld(
+                    filename.to_string(),
+                    working_directory.join("module.o").to_string_lossy().to_string()
+                )?;
+                Ok(())
+            };
+
+            self.compile_method(obj, method_name, args, embedding_map, py, &link_fn)
+        } else {
+            let link_fn = |module: &Module| {
+                let object_mem = target_machine
+                    .write_to_memory_buffer(module, FileType::Object)
+                    .expect("couldn't write module to object file buffer");
+                if let Ok(dyn_lib) = Linker::ld(object_mem.as_slice()) {
+                    if let Ok(mut file) = fs::File::create(filename) {
+                        file.write_all(&dyn_lib).expect("couldn't write linked library to file");
+                        Ok(())
+                    } else {
+                        Err(CompileError::new_err("failed to create file"))
+                    }
+                } else {
+                    Err(CompileError::new_err("linker failed to process object file"))
+                }
+            };
+
+            self.compile_method(obj, method_name, args, embedding_map, py, &link_fn)
+        }
+    }
+
     fn compile_method_to_mem(
         &mut self,
         obj: &PyAny,
@@ -925,10 +962,40 @@ impl Nac3 {
         embedding_map: &PyAny,
         py: Python,
     ) -> PyResult<PyObject> {
-        let filename_path = self.working_directory.path().join("module.elf");
-        let filename = filename_path.to_str().unwrap();
-        self.compile_method_to_file(obj, method_name, args, filename, embedding_map, py)?;
-        Ok(PyBytes::new(py, &fs::read(filename).unwrap()).into())
+        let target_machine = self.get_llvm_target_machine();
+        
+        if self.isa == Isa::Host {
+            let link_fn = |module: &Module| {
+                let working_directory = self.working_directory.path().to_owned();
+                target_machine
+                    .write_to_file(&module, FileType::Object, &working_directory.join("module.o"))
+                    .expect("couldn't write module to file");
+
+                let filename_path = self.working_directory.path().join("module.elf");
+                let filename = filename_path.to_str().unwrap();
+                link_with_lld(
+                    filename.to_string(),
+                    working_directory.join("module.o").to_string_lossy().to_string()
+                )?;
+
+                Ok(PyBytes::new(py, &fs::read(filename).unwrap()).into())
+            };
+
+            self.compile_method(obj, method_name, args, embedding_map, py, &link_fn)
+        } else {
+            let link_fn = |module: &Module| {
+                let object_mem = target_machine
+                    .write_to_memory_buffer(&module, FileType::Object)
+                    .expect("couldn't write module to object file buffer");
+                if let Ok(dyn_lib) = Linker::ld(object_mem.as_slice()) {
+                    Ok(PyBytes::new(py, &dyn_lib).into())
+                } else {
+                    Err(CompileError::new_err("linker failed to process object file"))
+                }
+            };
+
+            self.compile_method(obj, method_name, args, embedding_map, py, &link_fn)
+        }
     }
 }