forked from M-Labs/nac3
nac3artiq: implement attribute writeback
We will only writeback attributes that are supported by the current RPC implementation: primitives, tuple and lists of lists... of primitives.
This commit is contained in:
parent
ba8ed6c663
commit
bf067e2481
|
@ -5,6 +5,7 @@ class EmbeddingMap:
|
|||
self.string_map = {}
|
||||
self.string_reverse_map = {}
|
||||
self.function_map = {}
|
||||
self.attributes_writeback = []
|
||||
|
||||
# preallocate exception names
|
||||
self.preallocate_runtime_exception_names(["RuntimeError",
|
||||
|
|
|
@ -6,7 +6,7 @@ use nac3core::{
|
|||
},
|
||||
symbol_resolver::ValueEnum,
|
||||
toplevel::{DefinitionId, GenCall},
|
||||
typecheck::typedef::{FunSignature, Type},
|
||||
typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum}
|
||||
};
|
||||
|
||||
use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef};
|
||||
|
@ -15,7 +15,9 @@ use inkwell::{
|
|||
context::Context, module::Linkage, types::IntType, values::BasicValueEnum, AddressSpace,
|
||||
};
|
||||
|
||||
use crate::timeline::TimeFns;
|
||||
use pyo3::{PyObject, PyResult, Python, types::{PyDict, PyList}};
|
||||
|
||||
use crate::{symbol_resolver::InnerResolver, timeline::TimeFns};
|
||||
|
||||
use std::{
|
||||
collections::hash_map::DefaultHasher,
|
||||
|
@ -270,8 +272,6 @@ fn gen_rpc_tag<'ctx, 'a>(
|
|||
buffer.push(b'l');
|
||||
gen_rpc_tag(ctx, *ty, buffer)?;
|
||||
}
|
||||
// we should return an error, this will be fixed after improving error message
|
||||
// as this requires returning an error during codegen
|
||||
_ => return Err(format!("Unsupported type: {:?}", ctx.unifier.stringify(ty))),
|
||||
}
|
||||
}
|
||||
|
@ -291,7 +291,7 @@ fn rpc_codegen_callback_fn<'ctx, 'a>(
|
|||
let int32 = ctx.ctx.i32_type();
|
||||
let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false);
|
||||
|
||||
let service_id = int32.const_int(fun.1 .0 as u64, false);
|
||||
let service_id = int32.const_int(fun.1.0 as u64, false);
|
||||
// -- setup rpc tags
|
||||
let mut tag = Vec::new();
|
||||
if obj.is_some() {
|
||||
|
@ -486,6 +486,81 @@ fn rpc_codegen_callback_fn<'ctx, 'a>(
|
|||
Ok(Some(result))
|
||||
}
|
||||
|
||||
pub fn attributes_writeback<'ctx, 'a>(
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
inner_resolver: &InnerResolver,
|
||||
host_attributes: PyObject,
|
||||
) -> Result<(), String> {
|
||||
Python::with_gil(|py| -> PyResult<Result<(), String>> {
|
||||
let host_attributes = host_attributes.cast_as::<PyList>(py)?;
|
||||
let top_levels = ctx.top_level.definitions.read();
|
||||
let globals = inner_resolver.global_value_ids.read();
|
||||
let int32 = ctx.ctx.i32_type();
|
||||
let zero = int32.const_zero();
|
||||
let mut values = Vec::new();
|
||||
let mut scratch_buffer = Vec::new();
|
||||
for (_, val) in globals.iter() {
|
||||
let val = val.as_ref(py);
|
||||
let ty = inner_resolver.get_obj_type(py, val, &mut ctx.unifier, &top_levels, &ctx.primitives)?;
|
||||
if let Err(ty) = ty {
|
||||
return Ok(Err(ty))
|
||||
}
|
||||
let ty = ty.unwrap();
|
||||
match &*ctx.unifier.get_ty(ty) {
|
||||
TypeEnum::TObj { fields, .. } => {
|
||||
// we only care about primitive attributes
|
||||
// for non-primitive attributes, they should be in another global
|
||||
let mut attributes = Vec::new();
|
||||
let obj = inner_resolver.get_obj_value(py, val, ctx, generator)?.unwrap();
|
||||
for (name, (field_ty, is_mutable)) in fields.iter() {
|
||||
if !is_mutable {
|
||||
continue
|
||||
}
|
||||
if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() {
|
||||
attributes.push(name.to_string());
|
||||
let index = ctx.get_attr_index(ty, *name);
|
||||
values.push((*field_ty, ctx.build_gep_and_load(
|
||||
obj.into_pointer_value(),
|
||||
&[zero, int32.const_int(index as u64, false)])));
|
||||
}
|
||||
}
|
||||
if !attributes.is_empty() {
|
||||
let pydict = PyDict::new(py);
|
||||
pydict.set_item("obj", val)?;
|
||||
pydict.set_item("fields", attributes)?;
|
||||
host_attributes.append(pydict)?;
|
||||
}
|
||||
},
|
||||
TypeEnum::TList { ty: elem_ty } => {
|
||||
if gen_rpc_tag(ctx, *elem_ty, &mut scratch_buffer).is_ok() {
|
||||
let pydict = PyDict::new(py);
|
||||
pydict.set_item("obj", val)?;
|
||||
host_attributes.append(pydict)?;
|
||||
values.push((ty, inner_resolver.get_obj_value(py, val, ctx, generator)?.unwrap()));
|
||||
}
|
||||
},
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
let fun = FunSignature {
|
||||
args: values.iter().enumerate().map(|(i, (ty, _))| FuncArg {
|
||||
name: i.to_string().into(),
|
||||
ty: *ty,
|
||||
default_value: None
|
||||
}).collect(),
|
||||
ret: ctx.primitives.none,
|
||||
vars: Default::default()
|
||||
};
|
||||
let args: Vec<_> = values.into_iter().map(|(_, val)| (None, ValueEnum::Dynamic(val))).collect();
|
||||
if let Err(e) = rpc_codegen_callback_fn(ctx, None, (&fun, DefinitionId(0)), args, generator) {
|
||||
return Ok(Err(e));
|
||||
}
|
||||
Ok(Ok(()))
|
||||
}).unwrap()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn rpc_codegen_callback() -> Arc<GenCall> {
|
||||
Arc::new(GenCall::new(Box::new(|ctx, obj, fun, args, generator| {
|
||||
rpc_codegen_callback_fn(ctx, obj, fun, args, generator)
|
||||
|
|
|
@ -10,6 +10,7 @@ use inkwell::{
|
|||
targets::*,
|
||||
OptimizationLevel,
|
||||
};
|
||||
use nac3core::codegen::gen_func_impl;
|
||||
use nac3core::toplevel::builtins::get_exn_constructor;
|
||||
use nac3core::typecheck::typedef::{TypeEnum, Unifier};
|
||||
use nac3parser::{
|
||||
|
@ -36,6 +37,7 @@ use nac3core::{
|
|||
|
||||
use tempfile::{self, TempDir};
|
||||
|
||||
use crate::codegen::attributes_writeback;
|
||||
use crate::{
|
||||
codegen::{rpc_codegen_callback, ArtiqCodeGenerator},
|
||||
symbol_resolver::{InnerResolver, PythonHelper, Resolver, DeferredEvaluationStore},
|
||||
|
@ -476,6 +478,8 @@ impl Nac3 {
|
|||
let store_obj = embedding_map.getattr("store_object").unwrap().to_object(py);
|
||||
let store_str = embedding_map.getattr("store_str").unwrap().to_object(py);
|
||||
let store_fun = embedding_map.getattr("store_function").unwrap().to_object(py);
|
||||
let host_attributes = embedding_map.getattr("attributes_writeback").unwrap().to_object(py);
|
||||
let global_value_ids: Arc<RwLock<HashMap<_, _>>> = Arc::new(RwLock::new(HashMap::new()));
|
||||
let helper = PythonHelper {
|
||||
id_fn: builtins.getattr("id").unwrap().to_object(py),
|
||||
len_fn: builtins.getattr("len").unwrap().to_object(py),
|
||||
|
@ -503,7 +507,6 @@ impl Nac3 {
|
|||
|
||||
let mut module_to_resolver_cache: HashMap<u64, _> = HashMap::new();
|
||||
|
||||
let global_value_ids = Arc::new(RwLock::new(HashSet::<u64>::new()));
|
||||
let mut rpc_ids = vec![];
|
||||
for (stmt, path, module) in self.top_levels.iter() {
|
||||
let py_module: &PyAny = module.extract(py)?;
|
||||
|
@ -617,7 +620,7 @@ impl Nac3 {
|
|||
};
|
||||
let mut synthesized =
|
||||
parse_program(&synthesized, "__nac3_synthesized_modinit__".to_string().into()).unwrap();
|
||||
let resolver = Arc::new(Resolver(Arc::new(InnerResolver {
|
||||
let inner_resolver = Arc::new(InnerResolver {
|
||||
id_to_type: builtins_ty.clone().into(),
|
||||
id_to_def: builtins_def.clone().into(),
|
||||
pyid_to_def: self.pyid_to_def.clone(),
|
||||
|
@ -634,17 +637,18 @@ impl Nac3 {
|
|||
string_store: self.string_store.clone(),
|
||||
exception_ids: self.exception_ids.clone(),
|
||||
deferred_eval_store: self.deferred_eval_store.clone(),
|
||||
}))) as Arc<dyn SymbolResolver + Send + Sync>;
|
||||
});
|
||||
let resolver = Arc::new(Resolver(inner_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
|
||||
let (_, def_id, _) = composer
|
||||
.register_top_level(synthesized.pop().unwrap(), Some(resolver.clone()), "".into())
|
||||
.unwrap();
|
||||
|
||||
let signature =
|
||||
let fun_signature =
|
||||
FunSignature { args: vec![], ret: self.primitive.none, vars: HashMap::new() };
|
||||
let mut store = ConcreteTypeStore::new();
|
||||
let mut cache = HashMap::new();
|
||||
let signature =
|
||||
store.from_signature(&mut composer.unifier, &self.primitive, &signature, &mut cache);
|
||||
store.from_signature(&mut composer.unifier, &self.primitive, &fun_signature, &mut cache);
|
||||
let signature = store.add_cty(signature);
|
||||
|
||||
if let Err(e) = composer.start_analysis(true) {
|
||||
|
@ -721,12 +725,29 @@ impl Nac3 {
|
|||
symbol_name: "__modinit__".to_string(),
|
||||
body: instance.body,
|
||||
signature,
|
||||
resolver,
|
||||
resolver: resolver.clone(),
|
||||
store,
|
||||
unifier_index: instance.unifier_id,
|
||||
calls: instance.calls,
|
||||
id: 0,
|
||||
};
|
||||
|
||||
let mut store = ConcreteTypeStore::new();
|
||||
let mut cache = HashMap::new();
|
||||
let signature =
|
||||
store.from_signature(&mut composer.unifier, &self.primitive, &fun_signature, &mut cache);
|
||||
let signature = store.add_cty(signature);
|
||||
let attributes_writeback_task = CodeGenTask {
|
||||
subst: Default::default(),
|
||||
symbol_name: "attributes_writeback".to_string(),
|
||||
body: Arc::new(Default::default()),
|
||||
signature,
|
||||
resolver,
|
||||
store,
|
||||
unifier_index: instance.unifier_id,
|
||||
calls: Arc::new(Default::default()),
|
||||
id: 0,
|
||||
};
|
||||
let isa = self.isa;
|
||||
let working_directory = self.working_directory.path().to_owned();
|
||||
|
||||
|
@ -746,14 +767,27 @@ impl Nac3 {
|
|||
.map(|s| Box::new(ArtiqCodeGenerator::new(s.to_string(), size_t, self.time_fns)))
|
||||
.collect();
|
||||
|
||||
let membuffer = membuffers.clone();
|
||||
py.allow_threads(|| {
|
||||
let (registry, handles) = WorkerRegistry::create_workers(threads, top_level.clone(), f);
|
||||
registry.add_task(task);
|
||||
registry.wait_tasks_complete(handles);
|
||||
|
||||
let mut generator = ArtiqCodeGenerator::new("attributes_writeback".to_string(), size_t, self.time_fns);
|
||||
let context = inkwell::context::Context::create();
|
||||
let module = context.create_module("attributes_writeback");
|
||||
let builder = context.create_builder();
|
||||
let (_, module, _) = gen_func_impl(&context, &mut generator, ®istry, builder, module,
|
||||
attributes_writeback_task, |generator, ctx| {
|
||||
attributes_writeback(ctx, generator, inner_resolver.as_ref(), host_attributes)
|
||||
}).unwrap();
|
||||
let buffer = module.write_bitcode_to_memory();
|
||||
let buffer = buffer.as_slice().into();
|
||||
membuffer.lock().push(buffer);
|
||||
});
|
||||
|
||||
let buffers = membuffers.lock();
|
||||
let context = inkwell::context::Context::create();
|
||||
let buffers = membuffers.lock();
|
||||
let main = context
|
||||
.create_module_from_ir(MemoryBuffer::create_from_memory_range(&buffers[0], "main"))
|
||||
.unwrap();
|
||||
|
@ -765,6 +799,11 @@ impl Nac3 {
|
|||
main.link_in_module(other)
|
||||
.map_err(|err| CompileError::new_err(err.to_string()))?;
|
||||
}
|
||||
let builder = context.create_builder();
|
||||
let modinit_return = main.get_function("__modinit__").unwrap().get_last_basic_block().unwrap().get_terminator().unwrap();
|
||||
builder.position_before(&modinit_return);
|
||||
builder.build_call(main.get_function("attributes_writeback").unwrap(), &[], "attributes_writeback");
|
||||
|
||||
main.link_in_module(load_irrt(&context))
|
||||
.map_err(|err| CompileError::new_err(err.to_string()))?;
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ use pyo3::{
|
|||
PyAny, PyObject, PyResult, Python,
|
||||
};
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
collections::HashMap,
|
||||
sync::{
|
||||
Arc,
|
||||
atomic::{AtomicBool, Ordering::Relaxed}
|
||||
|
@ -54,7 +54,7 @@ pub struct InnerResolver {
|
|||
pub id_to_pyval: RwLock<HashMap<StrRef, (u64, PyObject)>>,
|
||||
pub id_to_primitive: RwLock<HashMap<u64, PrimitiveValue>>,
|
||||
pub field_to_val: RwLock<HashMap<(u64, StrRef), Option<(u64, PyObject)>>>,
|
||||
pub global_value_ids: Arc<RwLock<HashSet<u64>>>,
|
||||
pub global_value_ids: Arc<RwLock<HashMap<u64, PyObject>>>,
|
||||
pub class_names: Mutex<HashMap<StrRef, Type>>,
|
||||
pub pyid_to_def: Arc<RwLock<HashMap<u64, DefinitionId>>>,
|
||||
pub pyid_to_type: Arc<RwLock<HashMap<u64, Type>>>,
|
||||
|
@ -503,7 +503,7 @@ impl InnerResolver {
|
|||
}
|
||||
}
|
||||
|
||||
fn get_obj_type(
|
||||
pub fn get_obj_type(
|
||||
&self,
|
||||
py: Python,
|
||||
obj: &PyAny,
|
||||
|
@ -605,7 +605,7 @@ impl InnerResolver {
|
|||
unreachable!("must be tobj")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
let ty = match self.get_obj_type(py, field_data, unifier, defs, primitives)? {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
|
@ -686,7 +686,7 @@ impl InnerResolver {
|
|||
}
|
||||
}
|
||||
|
||||
fn get_obj_value<'ctx, 'a>(
|
||||
pub fn get_obj_value<'ctx, 'a>(
|
||||
&self,
|
||||
py: Python,
|
||||
obj: &PyAny,
|
||||
|
@ -754,13 +754,13 @@ impl InnerResolver {
|
|||
.struct_type(&[ty.ptr_type(AddressSpace::Generic).into(), size_t.into()], false);
|
||||
|
||||
{
|
||||
if self.global_value_ids.read().contains(&id) {
|
||||
if self.global_value_ids.read().contains_key(&id) {
|
||||
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
|
||||
ctx.module.add_global(arr_ty, Some(AddressSpace::Generic), &id_str)
|
||||
});
|
||||
return Ok(Some(global.as_pointer_value().into()));
|
||||
} else {
|
||||
self.global_value_ids.write().insert(id);
|
||||
self.global_value_ids.write().insert(id, obj.into());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -834,13 +834,13 @@ impl InnerResolver {
|
|||
let ty = ctx.ctx.struct_type(&types, false);
|
||||
|
||||
{
|
||||
if self.global_value_ids.read().contains(&id) {
|
||||
if self.global_value_ids.read().contains_key(&id) {
|
||||
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
|
||||
ctx.module.add_global(ty, Some(AddressSpace::Generic), &id_str)
|
||||
});
|
||||
return Ok(Some(global.as_pointer_value().into()));
|
||||
} else {
|
||||
self.global_value_ids.write().insert(id);
|
||||
self.global_value_ids.write().insert(id, obj.into());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -869,13 +869,13 @@ impl InnerResolver {
|
|||
Some(v) => {
|
||||
let global_str = format!("{}_option", id);
|
||||
{
|
||||
if self.global_value_ids.read().contains(&id) {
|
||||
if self.global_value_ids.read().contains_key(&id) {
|
||||
let global = ctx.module.get_global(&global_str).unwrap_or_else(|| {
|
||||
ctx.module.add_global(v.get_type(), Some(AddressSpace::Generic), &global_str)
|
||||
});
|
||||
return Ok(Some(global.as_pointer_value().into()));
|
||||
} else {
|
||||
self.global_value_ids.write().insert(id);
|
||||
self.global_value_ids.write().insert(id, obj.into());
|
||||
}
|
||||
}
|
||||
let global = ctx.module.add_global(v.get_type(), Some(AddressSpace::Generic), &global_str);
|
||||
|
@ -902,13 +902,13 @@ impl InnerResolver {
|
|||
.get_element_type()
|
||||
.into_struct_type();
|
||||
{
|
||||
if self.global_value_ids.read().contains(&id) {
|
||||
if self.global_value_ids.read().contains_key(&id) {
|
||||
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
|
||||
ctx.module.add_global(ty, Some(AddressSpace::Generic), &id_str)
|
||||
});
|
||||
return Ok(Some(global.as_pointer_value().into()));
|
||||
} else {
|
||||
self.global_value_ids.write().insert(id);
|
||||
self.global_value_ids.write().insert(id, obj.into());
|
||||
}
|
||||
}
|
||||
// should be classes
|
||||
|
|
|
@ -355,9 +355,9 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn gen_string<G: CodeGenerator, S: Into<String>>(
|
||||
pub fn gen_string<S: Into<String>>(
|
||||
&mut self,
|
||||
generator: &mut G,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
s: S,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
self.gen_const(generator, &nac3parser::ast::Constant::Str(s.into()), self.primitives.str)
|
||||
|
|
|
@ -360,13 +360,14 @@ fn need_sret<'ctx>(ctx: &'ctx Context, ty: BasicTypeEnum<'ctx>) -> bool {
|
|||
need_sret_impl(ctx, ty, true)
|
||||
}
|
||||
|
||||
pub fn gen_func<'ctx, G: CodeGenerator>(
|
||||
pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenContext) -> Result<(), String>> (
|
||||
context: &'ctx Context,
|
||||
generator: &mut G,
|
||||
registry: &WorkerRegistry,
|
||||
builder: Builder<'ctx>,
|
||||
module: Module<'ctx>,
|
||||
task: CodeGenTask,
|
||||
codegen_function: F
|
||||
) -> Result<(Builder<'ctx>, Module<'ctx>, FunctionValue<'ctx>), (Builder<'ctx>, String)> {
|
||||
let top_level_ctx = registry.top_level_ctx.clone();
|
||||
let static_value_store = registry.static_value_store.clone();
|
||||
|
@ -572,25 +573,34 @@ pub fn gen_func<'ctx, G: CodeGenerator>(
|
|||
need_sret: has_sret
|
||||
};
|
||||
|
||||
let mut err = None;
|
||||
for stmt in task.body.iter() {
|
||||
if let Err(e) = generator.gen_stmt(&mut code_gen_context, stmt) {
|
||||
err = Some(e);
|
||||
break;
|
||||
}
|
||||
if code_gen_context.is_terminated() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
let result = codegen_function(generator, &mut code_gen_context);
|
||||
|
||||
// after static analysis, only void functions can have no return at the end.
|
||||
if !code_gen_context.is_terminated() {
|
||||
code_gen_context.builder.build_return(None);
|
||||
}
|
||||
|
||||
let CodeGenContext { builder, module, .. } = code_gen_context;
|
||||
if let Some(e) = err {
|
||||
if let Err(e) = result {
|
||||
return Err((builder, e));
|
||||
}
|
||||
|
||||
Ok((builder, module, fn_val))
|
||||
}
|
||||
|
||||
pub fn gen_func<'ctx, G: CodeGenerator>(
|
||||
context: &'ctx Context,
|
||||
generator: &mut G,
|
||||
registry: &WorkerRegistry,
|
||||
builder: Builder<'ctx>,
|
||||
module: Module<'ctx>,
|
||||
task: CodeGenTask,
|
||||
) -> Result<(Builder<'ctx>, Module<'ctx>, FunctionValue<'ctx>), (Builder<'ctx>, String)> {
|
||||
let body = task.body.clone();
|
||||
gen_func_impl(context, generator, registry, builder, module, task, |generator, ctx| {
|
||||
for stmt in body.iter() {
|
||||
generator.gen_stmt(ctx, stmt)?;
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue