forked from M-Labs/nac3
Compare commits
20 Commits
8031e9143c
03be4697c5
6328e1e740
0e61c9f658
983017a693
5055c02200
b1106a6f20
74ffff6a8b
ce52c17781
22d09939a6
1426ee20ef
bb3e6b64b8
639c31700e
3ce649c0b7
6b1cba59e0
a2360870b7
f3324de3b9
cb7f667e4d
59c7478950
ee72072c16
nac3artiq/demo/demo_import_mod.py (new file, 27 lines)
@ -0,0 +1,27 @@
from min_artiq import kernel, KernelInvariant, nac3
import min_artiq as artiq


@nac3
class Demo:
    core: KernelInvariant[artiq.Core]
    led0: KernelInvariant[artiq.TTLOut]
    led1: KernelInvariant[artiq.TTLOut]

    def __init__(self):
        self.core = artiq.Core()
        self.led0 = artiq.TTLOut(self.core, 18)
        self.led1 = artiq.TTLOut(self.core, 19)

    @kernel
    def run(self):
        self.core.reset()
        while True:
            with artiq.parallel:
                self.led0.pulse(100.*artiq.ms)
                self.led1.pulse(100.*artiq.ms)
            self.core.delay(100.*artiq.ms)


if __name__ == "__main__":
    Demo().run()
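The new demo exercises attribute access on an imported module from kernel code (artiq.Core, artiq.TTLOut, artiq.parallel, artiq.ms). The compiler changes further down also accept module-level globals annotated with Kernel or KernelInvariant; a minimal hypothetical sketch of such a global in a user module follows (names are illustrative and not part of this changeset; whether the module needs explicit registration depends on min_artiq internals not shown here):

# hypothetical user module, not part of this diff
from min_artiq import KernelInvariant
from numpy import int32

# module-level kernel-invariant global, readable from kernels that import this module
blink_count: KernelInvariant[int32] = int32(10)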
@ -1,9 +1,9 @@
from inspect import getfullargspec
from functools import wraps
from types import SimpleNamespace
from numpy import int32, int64
from typing import Generic, TypeVar
from math import floor, ceil
from numpy import int32, int64, uint32, uint64, float64, bool_, str_, ndarray
from types import GenericAlias, ModuleType, SimpleNamespace
from typing import _GenericAlias, Generic, TypeVar

import nac3artiq

@ -85,13 +85,46 @@ def ceil64(x):
import device_db
core_arguments = device_db.device_db["core"]["arguments"]

artiq_builtins = {
builtins = {
    "int": int,
    "float": float,
    "bool": bool,
    "str": str,
    "list": list,
    "tuple": tuple,
    "Exception": Exception,

    "types": {
        "GenericAlias": GenericAlias,
        "ModuleType": ModuleType,
    },

    "typing": {
        "_GenericAlias": _GenericAlias,
        "TypeVar": TypeVar,
    },

    "numpy": {
        "int32": int32,
        "int64": int64,
        "uint32": uint32,
        "uint64": uint64,
        "float64": float64,
        "bool_": bool_,
        "str_": str_,
        "ndarray": ndarray,
    },

    "artiq": {
        "Kernel": Kernel,
        "KernelInvariant": KernelInvariant,
        "_ConstGenericMarker": _ConstGenericMarker,
        "none": none,
        "virtual": virtual,
        "_ConstGenericMarker": _ConstGenericMarker,
        "Option": Option,
    },
}
compiler = nac3artiq.NAC3(core_arguments["target"], artiq_builtins)
compiler = nac3artiq.NAC3(core_arguments["target"], builtins)
allow_registration = True
# Delay NAC3 analysis until all referenced variables are supposed to exist on the CPython side.
registered_functions = set()
@ -152,9 +185,9 @@ def nac3(cls):
    return cls


ms = 1e-3
us = 1e-6
ns = 1e-9
ms: KernelInvariant[float] = 1e-3
us: KernelInvariant[float] = 1e-6
ns: KernelInvariant[float] = 1e-9

@extern
def rtio_init():
@ -335,9 +368,9 @@ class UnwrapNoneError(Exception):
    """raised when unwrapping a none value"""
    artiq_builtin = True

parallel = KernelContextManager()
legacy_parallel = KernelContextManager()
sequential = KernelContextManager()
parallel: KernelInvariant[KernelContextManager] = KernelContextManager()
legacy_parallel: KernelInvariant[KernelContextManager] = KernelContextManager()
sequential: KernelInvariant[KernelContextManager] = KernelContextManager()

special_ids = {
    "parallel": id(parallel),
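The dictionary handed to nac3artiq.NAC3 is now nested: plain Python builtins at the top level, and host objects grouped under per-module keys ("types", "typing", "numpy", "artiq"). A rough Python sketch of the lookup the Rust side performs on this structure (the real implementation is the get_artiq_builtin closure in lib.rs later in this diff; this helper is only illustrative):

def lookup_builtin(builtins, name, mod_name=None):
    # Fetch a host object from the nested builtins dict, optionally scoped to a module key.
    scope = builtins[mod_name] if mod_name is not None else builtins
    return scope[name]

# e.g. lookup_builtin(builtins, "int32", "numpy") returns numpy.int32,
#      lookup_builtin(builtins, "Option", "artiq") returns the artiq Option helper.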
@ -1015,7 +1015,7 @@ pub fn attributes_writeback<'ctx>(
                *field_ty,
                ctx.build_gep_and_load(
                    obj.into_pointer_value(),
                    &[zero, int32.const_int(index as u64, false)],
                    &[zero, int32.const_int(index.unwrap() as u64, false)],
                    None,
                ),
            ));
@ -1056,7 +1056,7 @@ pub fn attributes_writeback<'ctx>(
                *field_ty,
                ctx.build_gep_and_load(
                    obj.into_pointer_value(),
                    &[zero, int32.const_int(index as u64, false)],
                    &[zero, int32.const_int(index.unwrap() as u64, false)],
                    None,
                ),
            ));
nac3artiq/src/debug.rs (new file, 273 lines)
@ -0,0 +1,273 @@
|
||||
use itertools::Itertools;
|
||||
|
||||
use nac3core::{toplevel::TopLevelDef, typecheck::typedef::Unifier};
|
||||
|
||||
use super::{InnerResolver, symbol_resolver::PyValueHandle};
|
||||
|
||||
impl InnerResolver {
|
||||
pub fn debug_str(&self, tld: Option<&[TopLevelDef]>, unifier: &Option<&mut Unifier>) -> String {
|
||||
fn fmt_elems(elems: &str) -> String {
|
||||
if elems.is_empty() { String::new() } else { format!("\n{elems}\n\t") }
|
||||
}
|
||||
fn stringify_pyvalue_handle(handle: &PyValueHandle) -> String {
|
||||
format!("(id: {}, value: {})", handle.0, handle.1)
|
||||
}
|
||||
fn stringify_tld(tld: &TopLevelDef) -> String {
|
||||
match tld {
|
||||
TopLevelDef::Module { name, .. } => {
|
||||
format!("TopLevelDef::Module {{ name: {name} }}")
|
||||
}
|
||||
TopLevelDef::Class { name, .. } => {
|
||||
format!("TopLevelDef::Class {{ name: {name} }}")
|
||||
}
|
||||
TopLevelDef::Function { name, .. } => {
|
||||
format!("TopLevelDef::Function {{ name: {name} }}")
|
||||
}
|
||||
TopLevelDef::Variable { name, .. } => {
|
||||
format!("TopLevelDef::Variable {{ name: {name} }}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut str = String::new();
|
||||
str.push_str("nac3artiq::InnerResolver {");
|
||||
|
||||
{
|
||||
let id_to_type = self.id_to_type.read();
|
||||
str.push_str(
|
||||
format!(
|
||||
"\n\tid_to_type: {{{}}},",
|
||||
fmt_elems(
|
||||
id_to_type
|
||||
.iter()
|
||||
.sorted_by_cached_key(|(k, _)| k.to_string())
|
||||
.map(|(k, v)| {
|
||||
let ty_str = unifier.as_ref().map_or_else(
|
||||
|| format!("{v:?}"),
|
||||
|unifier| unifier.stringify(*v),
|
||||
);
|
||||
format!("\t\t{k} -> {ty_str}")
|
||||
})
|
||||
.join(",\n")
|
||||
.as_str()
|
||||
),
|
||||
)
|
||||
.as_str(),
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
let id_to_def = self.id_to_def.read();
|
||||
str.push_str(
|
||||
format!(
|
||||
"\n\tid_to_def: {{{}}},",
|
||||
fmt_elems(
|
||||
id_to_def
|
||||
.iter()
|
||||
.sorted_by_cached_key(|(k, _)| k.to_string())
|
||||
.map(|(k, v)| {
|
||||
let tld_str = tld.map_or_else(
|
||||
|| format!("{v:?}"),
|
||||
|tlds| stringify_tld(&tlds[v.0]),
|
||||
);
|
||||
format!("\t\t{k} -> {tld_str}")
|
||||
})
|
||||
.join(",\n")
|
||||
.as_str()
|
||||
)
|
||||
)
|
||||
.as_str(),
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
let id_to_pyval = self.id_to_pyval.read();
|
||||
str.push_str(
|
||||
format!(
|
||||
"\n\tid_to_pyval: {{{}}},",
|
||||
fmt_elems(
|
||||
id_to_pyval
|
||||
.iter()
|
||||
.sorted_by_cached_key(|(k, _)| k.to_string())
|
||||
.map(|(k, v)| { format!("\t\t{k} -> {}", stringify_pyvalue_handle(v)) })
|
||||
.join(",\n")
|
||||
.as_str()
|
||||
)
|
||||
)
|
||||
.as_str(),
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
let id_to_primitive = self.id_to_primitive.read();
|
||||
str.push_str(
|
||||
format!(
|
||||
"\n\tid_to_primitive: {{{}}},",
|
||||
fmt_elems(
|
||||
id_to_primitive
|
||||
.iter()
|
||||
.sorted_by_key(|(k, _)| *k)
|
||||
.map(|(k, v)| { format!("\t\t{k} -> {v:?}") })
|
||||
.join(",\n")
|
||||
.as_str()
|
||||
)
|
||||
)
|
||||
.as_str(),
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
let field_to_val = self.field_to_val.read();
|
||||
str.push_str(
|
||||
format!(
|
||||
"\n\tfield_to_val: {{{}}},",
|
||||
fmt_elems(
|
||||
field_to_val
|
||||
.iter()
|
||||
.sorted_by_key(|((id, _), _)| *id)
|
||||
.map(|((id, name), pyval)| {
|
||||
format!(
|
||||
"\t\t({id}, {name}) -> {}",
|
||||
pyval.as_ref().map_or_else(
|
||||
|| String::from("None"),
|
||||
|pyval| format!(
|
||||
"Some({})",
|
||||
stringify_pyvalue_handle(pyval)
|
||||
)
|
||||
)
|
||||
)
|
||||
})
|
||||
.join(",\n")
|
||||
.as_str()
|
||||
)
|
||||
)
|
||||
.as_str(),
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
let global_value_ids = self.global_value_ids.read();
|
||||
str.push_str(
|
||||
format!(
|
||||
"\n\tglobal_value_ids: {{{}}},",
|
||||
fmt_elems(
|
||||
global_value_ids
|
||||
.iter()
|
||||
.sorted_by_key(|(k, _)| *k)
|
||||
.map(|(k, v)| format!("\t\t{k} -> {v}"))
|
||||
.join(",\n")
|
||||
.as_str()
|
||||
)
|
||||
)
|
||||
.as_str(),
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
let pyid_to_def = self.pyid_to_def.read();
|
||||
str.push_str(
|
||||
format!(
|
||||
"\n\tpyid_to_def: {{{}}},",
|
||||
fmt_elems(
|
||||
pyid_to_def
|
||||
.iter()
|
||||
.sorted_by_key(|(k, _)| *k)
|
||||
.map(|(k, v)| {
|
||||
let tld_str = tld.map_or_else(
|
||||
|| format!("{v:?}"),
|
||||
|tlds| stringify_tld(&tlds[v.0]),
|
||||
);
|
||||
format!("\t\t{k} -> {tld_str}")
|
||||
})
|
||||
.join(",\n")
|
||||
.as_str()
|
||||
)
|
||||
)
|
||||
.as_str(),
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
let pyid_to_type = self.pyid_to_type.read();
|
||||
str.push_str(
|
||||
format!(
|
||||
"\n\tpyid_to_type: {{{}}},",
|
||||
fmt_elems(
|
||||
pyid_to_type
|
||||
.iter()
|
||||
.sorted_by_key(|(k, _)| *k)
|
||||
.map(|(k, v)| {
|
||||
let ty_str = unifier.as_ref().map_or_else(
|
||||
|| format!("{v:?}"),
|
||||
|unifier| unifier.stringify(*v),
|
||||
);
|
||||
format!("\t\t{k} -> {ty_str}")
|
||||
})
|
||||
.join(",\n")
|
||||
.as_str()
|
||||
)
|
||||
)
|
||||
.as_str(),
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
let string_store = self.string_store.read();
|
||||
str.push_str(
|
||||
format!(
|
||||
"\n\tstring_store: {{{}}},",
|
||||
fmt_elems(
|
||||
string_store
|
||||
.iter()
|
||||
.sorted_by_key(|(k, _)| *k)
|
||||
.map(|(k, v)| format!("\t\t{k} -> {v}"))
|
||||
.join(",\n")
|
||||
.as_str()
|
||||
)
|
||||
)
|
||||
.as_str(),
|
||||
);
|
||||
}
|
||||
|
||||
{
|
||||
let exception_ids = self.exception_ids.read();
|
||||
str.push_str(
|
||||
format!(
|
||||
"\n\texception_ids: {{{}}},",
|
||||
fmt_elems(
|
||||
exception_ids
|
||||
.iter()
|
||||
.sorted_by_key(|(k, _)| *k)
|
||||
.map(|(k, v)| format!("\t\t{k} -> {v}"))
|
||||
.join(",\n")
|
||||
.as_str()
|
||||
)
|
||||
)
|
||||
.as_str(),
|
||||
);
|
||||
}
|
||||
|
||||
let name_to_pyid = &self.name_to_pyid;
|
||||
str.push_str(
|
||||
format!(
|
||||
"\n\tname_to_pyid: {{{}}},",
|
||||
fmt_elems(
|
||||
name_to_pyid
|
||||
.iter()
|
||||
.sorted_by_cached_key(|(k, _)| k.to_string())
|
||||
.map(|(k, v)| format!("\t\t{k} -> {v}"))
|
||||
.join(",\n")
|
||||
.as_str()
|
||||
)
|
||||
)
|
||||
.as_str(),
|
||||
);
|
||||
|
||||
let module = &self.module;
|
||||
str.push_str(format!("\n\tmodule: {module}").as_str());
|
||||
|
||||
str.push_str("\n}");
|
||||
|
||||
str
|
||||
}
|
||||
}
|
@ -10,9 +10,11 @@
|
||||
)]
|
||||
|
||||
use std::{
|
||||
cell::LazyCell,
|
||||
collections::{HashMap, HashSet},
|
||||
fs,
|
||||
io::Write,
|
||||
path::Path,
|
||||
process::Command,
|
||||
rc::Rc,
|
||||
sync::Arc,
|
||||
@ -66,9 +68,13 @@ use symbol_resolver::{DeferredEvaluationStore, InnerResolver, PythonHelper, Reso
|
||||
use timeline::TimeFns;
|
||||
|
||||
mod codegen;
|
||||
mod debug;
|
||||
mod symbol_resolver;
|
||||
mod timeline;
|
||||
|
||||
const ENV_NAC3_EMIT_LLVM_BC: &str = "NAC3_EMIT_LLVM_BC";
|
||||
const ENV_NAC3_EMIT_LLVM_LL: &str = "NAC3_EMIT_LLVM_LL";
|
||||
|
||||
#[derive(PartialEq, Clone, Copy)]
|
||||
enum Isa {
|
||||
Host,
|
||||
@ -160,6 +166,8 @@ pub struct PrimitivePythonId {
|
||||
virtual_id: u64,
|
||||
option: u64,
|
||||
module: u64,
|
||||
kernel: u64,
|
||||
kernel_invariant: u64,
|
||||
}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
@ -228,6 +236,17 @@ impl Nac3 {
|
||||
let parser_result = parse_program(&source, source_file.into())
|
||||
.map_err(|e| exceptions::PySyntaxError::new_err(format!("parse error: {e}")))?;
|
||||
|
||||
let id_fn = LazyCell::new(|| {
|
||||
Python::with_gil(|py| {
|
||||
PyModule::import(py, "builtins").unwrap().getattr("id").unwrap().unbind()
|
||||
})
|
||||
});
|
||||
let get_type_hints_fn = LazyCell::new(|| {
|
||||
Python::with_gil(|py| {
|
||||
PyModule::import(py, "typing").unwrap().getattr("get_type_hints").unwrap().unbind()
|
||||
})
|
||||
});
|
||||
|
||||
for mut stmt in parser_result {
|
||||
let include = match stmt.node {
|
||||
StmtKind::ClassDef { ref decorator_list, ref mut body, ref mut bases, .. } => {
|
||||
@ -244,7 +263,6 @@ impl Nac3 {
|
||||
// Drop unregistered (i.e. host-only) base classes.
|
||||
bases.retain(|base| {
|
||||
Python::with_gil(|py| -> PyResult<bool> {
|
||||
let id_fn = PyModule::import(py, "builtins")?.getattr("id")?;
|
||||
match &base.node {
|
||||
ExprKind::Name { id, .. } => {
|
||||
if *id == "Exception".into() {
|
||||
@ -252,7 +270,8 @@ impl Nac3 {
|
||||
} else {
|
||||
let base_obj =
|
||||
module.bind(py).getattr(id.to_string().as_str())?;
|
||||
let base_id = id_fn.call1((base_obj,))?.extract()?;
|
||||
let base_id =
|
||||
id_fn.bind(py).call1((base_obj,))?.extract()?;
|
||||
Ok(registered_class_ids.contains(&base_id))
|
||||
}
|
||||
}
|
||||
@ -285,10 +304,28 @@ impl Nac3 {
|
||||
}
|
||||
})
|
||||
}
|
||||
// Allow global variable declaration with `Kernel` type annotation
|
||||
StmtKind::AnnAssign { ref annotation, .. } => {
|
||||
matches!(&annotation.node, ExprKind::Subscript { value, .. } if matches!(&value.node, ExprKind::Name {id, ..} if id == &"Kernel".into()))
|
||||
|
||||
// Allow global variable declaration with `Kernel` or `KernelInvariant` type annotation
|
||||
StmtKind::AnnAssign { ref target, .. } => match &target.node {
|
||||
ExprKind::Name { id, .. } => Python::with_gil(|py| {
|
||||
let py_type_hints =
|
||||
get_type_hints_fn.bind(py).call1((module.bind(py),)).unwrap();
|
||||
let py_type_hints = py_type_hints.downcast::<PyDict>().unwrap();
|
||||
let var_type_hint =
|
||||
py_type_hints.get_item(id.to_string().as_str()).unwrap().unwrap();
|
||||
let var_type = var_type_hint.getattr_opt("__origin__").unwrap();
|
||||
if let Some(var_type) = var_type {
|
||||
let var_type_id = id_fn.bind(py).call1((var_type,)).unwrap();
|
||||
let var_type_id = var_type_id.extract::<u64>().unwrap();
|
||||
|
||||
[self.primitive_ids.kernel, self.primitive_ids.kernel_invariant]
|
||||
.contains(&var_type_id)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}),
|
||||
_ => false,
|
||||
},
|
||||
_ => false,
|
||||
};
|
||||
|
||||
@ -568,7 +605,10 @@ impl Nac3 {
|
||||
py,
|
||||
(
|
||||
def_id.0.into_py_any(py)?,
|
||||
module.getattr(py, name.to_string().as_str()).unwrap(),
|
||||
module
|
||||
.bind(py)
|
||||
.getattr(name.to_string().as_str())
|
||||
.unwrap(),
|
||||
),
|
||||
)
|
||||
.unwrap();
|
||||
@ -593,7 +633,7 @@ impl Nac3 {
|
||||
}
|
||||
StmtKind::ClassDef { name, body, .. } => {
|
||||
let class_name = name.to_string();
|
||||
let class_obj = Arc::new(module.getattr(py, class_name.as_str()).unwrap());
|
||||
let class_obj = module.bind(py).getattr(class_name.as_str()).unwrap();
|
||||
for stmt in body {
|
||||
if let StmtKind::FunctionDef { name, decorator_list, .. } = &stmt.node {
|
||||
for decorator in decorator_list {
|
||||
@ -765,9 +805,7 @@ impl Nac3 {
|
||||
py,
|
||||
(
|
||||
id.0.into_py_any(py)?,
|
||||
class_def
|
||||
.getattr(py, name.to_string().as_str())
|
||||
.unwrap(),
|
||||
class_def.getattr(name.to_string().as_str()).unwrap(),
|
||||
),
|
||||
)
|
||||
.unwrap();
|
||||
@ -901,6 +939,18 @@ impl Nac3 {
|
||||
|
||||
embedding_map.setattr("expects_return", has_return).unwrap();
|
||||
|
||||
let emit_llvm_bc = std::env::var(ENV_NAC3_EMIT_LLVM_BC).is_ok();
|
||||
let emit_llvm_ll = std::env::var(ENV_NAC3_EMIT_LLVM_LL).is_ok();
|
||||
|
||||
let emit_llvm = |module: &Module<'_>, filename: &str| {
|
||||
if emit_llvm_bc {
|
||||
module.write_bitcode_to_path(Path::new(format!("{filename}.bc").as_str()));
|
||||
}
|
||||
if emit_llvm_ll {
|
||||
module.print_to_file(Path::new(format!("{filename}.ll").as_str())).unwrap();
|
||||
}
|
||||
};
|
||||
|
||||
// Link all modules into `main`.
|
||||
let buffers = membuffers.lock();
|
||||
let main = context
|
||||
@ -909,6 +959,8 @@ impl Nac3 {
|
||||
"main",
|
||||
))
|
||||
.unwrap();
|
||||
emit_llvm(&main, "main");
|
||||
|
||||
for buffer in buffers.iter().rev().skip(1) {
|
||||
let other = context
|
||||
.create_module_from_ir(MemoryBuffer::create_from_memory_range(buffer, "main"))
|
||||
@ -916,7 +968,10 @@ impl Nac3 {
|
||||
|
||||
main.link_in_module(other).map_err(|err| CompileError::new_err(err.to_string()))?;
|
||||
}
|
||||
emit_llvm(&main, "main.merged");
|
||||
|
||||
main.link_in_module(irrt).map_err(|err| CompileError::new_err(err.to_string()))?;
|
||||
emit_llvm(&main, "main.fat");
|
||||
|
||||
let mut function_iter = main.get_first_function();
|
||||
while let Some(func) = function_iter {
|
||||
@ -936,6 +991,8 @@ impl Nac3 {
|
||||
global_option = global.get_next_global();
|
||||
}
|
||||
|
||||
emit_llvm(&main, "main.pre-opt");
|
||||
|
||||
let target_machine = self
|
||||
.llvm_options
|
||||
.target
|
||||
@ -950,6 +1007,8 @@ impl Nac3 {
|
||||
panic!("Failed to run optimization for module `main`: {}", err.to_string());
|
||||
}
|
||||
|
||||
emit_llvm(&main, "main.post-opt");
|
||||
|
||||
Python::with_gil(|py| {
|
||||
let string_store = self.string_store.read();
|
||||
let mut string_store_vec = string_store.iter().collect::<Vec<_>>();
|
||||
@ -1133,42 +1192,59 @@ impl Nac3 {
|
||||
|
||||
let builtins_mod = PyModule::import(py, "builtins").unwrap();
|
||||
let id_fn = builtins_mod.getattr("id").unwrap();
|
||||
let numpy_mod = PyModule::import(py, "numpy").unwrap();
|
||||
let typing_mod = PyModule::import(py, "typing").unwrap();
|
||||
let types_mod = PyModule::import(py, "types").unwrap();
|
||||
|
||||
let get_id = |x: &Bound<PyAny>| id_fn.call1((x,)).and_then(|id| id.extract()).unwrap();
|
||||
let get_attr_id = |obj: &Bound<PyModule>, attr| {
|
||||
id_fn.call1((obj.getattr(attr).unwrap(),)).unwrap().extract().unwrap()
|
||||
let get_artiq_builtin = |mod_name: Option<&str>, name: &str| -> Bound<PyAny> {
|
||||
if let Some(mod_name) = mod_name {
|
||||
artiq_builtins
|
||||
.get_item(mod_name)
|
||||
.unwrap()
|
||||
.unwrap_or_else(|| {
|
||||
panic!("no module key '{mod_name}' present in artiq_builtins")
|
||||
})
|
||||
.downcast::<PyDict>()
|
||||
.unwrap()
|
||||
.get_item(name)
|
||||
.unwrap()
|
||||
.unwrap_or_else(|| {
|
||||
panic!("no key '{name}' present in artiq_builtins.{mod_name}")
|
||||
})
|
||||
} else {
|
||||
artiq_builtins
|
||||
.get_item(name)
|
||||
.unwrap()
|
||||
.unwrap_or_else(|| panic!("no key '{name}' present in artiq_builtins"))
|
||||
}
|
||||
};
|
||||
|
||||
let primitive_ids = PrimitivePythonId {
|
||||
virtual_id: get_id(&artiq_builtins.get_item("virtual").ok().flatten().unwrap()),
|
||||
virtual_id: get_id(&get_artiq_builtin(Some("artiq"), "virtual")),
|
||||
generic_alias: (
|
||||
get_attr_id(&typing_mod, "_GenericAlias"),
|
||||
get_attr_id(&types_mod, "GenericAlias"),
|
||||
get_id(&get_artiq_builtin(Some("typing"), "_GenericAlias")),
|
||||
get_id(&get_artiq_builtin(Some("types"), "GenericAlias")),
|
||||
),
|
||||
none: get_id(&artiq_builtins.get_item("none").ok().flatten().unwrap()),
|
||||
typevar: get_attr_id(&typing_mod, "TypeVar"),
|
||||
const_generic_marker: get_id(
|
||||
&artiq_builtins.get_item("_ConstGenericMarker").ok().flatten().unwrap(),
|
||||
),
|
||||
int: get_attr_id(&builtins_mod, "int"),
|
||||
int32: get_attr_id(&numpy_mod, "int32"),
|
||||
int64: get_attr_id(&numpy_mod, "int64"),
|
||||
uint32: get_attr_id(&numpy_mod, "uint32"),
|
||||
uint64: get_attr_id(&numpy_mod, "uint64"),
|
||||
bool: get_attr_id(&builtins_mod, "bool"),
|
||||
np_bool_: get_attr_id(&numpy_mod, "bool_"),
|
||||
string: get_attr_id(&builtins_mod, "str"),
|
||||
np_str_: get_attr_id(&numpy_mod, "str_"),
|
||||
float: get_attr_id(&builtins_mod, "float"),
|
||||
float64: get_attr_id(&numpy_mod, "float64"),
|
||||
list: get_attr_id(&builtins_mod, "list"),
|
||||
ndarray: get_attr_id(&numpy_mod, "ndarray"),
|
||||
tuple: get_attr_id(&builtins_mod, "tuple"),
|
||||
exception: get_attr_id(&builtins_mod, "Exception"),
|
||||
option: get_id(&artiq_builtins.get_item("Option").ok().flatten().unwrap()),
|
||||
module: get_attr_id(&types_mod, "ModuleType"),
|
||||
none: get_id(&get_artiq_builtin(Some("artiq"), "none")),
|
||||
typevar: get_id(&get_artiq_builtin(Some("typing"), "TypeVar")),
|
||||
const_generic_marker: get_id(&get_artiq_builtin(Some("artiq"), "_ConstGenericMarker")),
|
||||
int: get_id(&get_artiq_builtin(None, "int")),
|
||||
int32: get_id(&get_artiq_builtin(Some("numpy"), "int32")),
|
||||
int64: get_id(&get_artiq_builtin(Some("numpy"), "int64")),
|
||||
uint32: get_id(&get_artiq_builtin(Some("numpy"), "uint32")),
|
||||
uint64: get_id(&get_artiq_builtin(Some("numpy"), "uint64")),
|
||||
bool: get_id(&get_artiq_builtin(None, "bool")),
|
||||
np_bool_: get_id(&get_artiq_builtin(Some("numpy"), "bool_")),
|
||||
string: get_id(&get_artiq_builtin(None, "str")),
|
||||
np_str_: get_id(&get_artiq_builtin(Some("numpy"), "str_")),
|
||||
float: get_id(&get_artiq_builtin(None, "float")),
|
||||
float64: get_id(&get_artiq_builtin(Some("numpy"), "float64")),
|
||||
list: get_id(&get_artiq_builtin(None, "list")),
|
||||
ndarray: get_id(&get_artiq_builtin(Some("numpy"), "ndarray")),
|
||||
tuple: get_id(&get_artiq_builtin(None, "tuple")),
|
||||
exception: get_id(&get_artiq_builtin(None, "Exception")),
|
||||
option: get_id(&get_artiq_builtin(Some("artiq"), "Option")),
|
||||
module: get_id(&get_artiq_builtin(Some("types"), "ModuleType")),
|
||||
kernel: get_id(&get_artiq_builtin(Some("artiq"), "Kernel")),
|
||||
kernel_invariant: get_id(&get_artiq_builtin(Some("artiq"), "KernelInvariant")),
|
||||
};
|
||||
|
||||
let working_directory = tempfile::Builder::new().prefix("nac3-").tempdir().unwrap();
|
||||
@ -1314,9 +1390,8 @@ impl Nac3 {
|
||||
py: Python<'py>,
|
||||
) -> PyResult<()> {
|
||||
let target_machine = self.get_llvm_target_machine();
|
||||
|
||||
if self.isa == Isa::Host {
|
||||
let link_fn = |module: &Module| {
|
||||
if self.isa == Isa::Host {
|
||||
let working_directory = self.working_directory.path().to_owned();
|
||||
target_machine
|
||||
.write_to_file(module, FileType::Object, &working_directory.join("module.o"))
|
||||
@ -1326,11 +1401,7 @@ impl Nac3 {
|
||||
working_directory.join("module.o").to_string_lossy().to_string(),
|
||||
)?;
|
||||
Ok(())
|
||||
};
|
||||
|
||||
self.compile_method(obj, method_name, args, embedding_map, py, &link_fn)
|
||||
} else {
|
||||
let link_fn = |module: &Module| {
|
||||
let object_mem = target_machine
|
||||
.write_to_memory_buffer(module, FileType::Object)
|
||||
.expect("couldn't write module to object file buffer");
|
||||
@ -1344,11 +1415,11 @@ impl Nac3 {
|
||||
} else {
|
||||
Err(CompileError::new_err("linker failed to process object file"))
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
self.compile_method(obj, method_name, args, embedding_map, py, &link_fn)
|
||||
}
|
||||
}
|
||||
|
||||
fn compile_method_to_mem<'py>(
|
||||
&mut self,
|
||||
@ -1359,9 +1430,8 @@ impl Nac3 {
|
||||
py: Python<'py>,
|
||||
) -> PyResult<PyObject> {
|
||||
let target_machine = self.get_llvm_target_machine();
|
||||
|
||||
if self.isa == Isa::Host {
|
||||
let link_fn = |module: &Module| {
|
||||
if self.isa == Isa::Host {
|
||||
let working_directory = self.working_directory.path().to_owned();
|
||||
target_machine
|
||||
.write_to_file(module, FileType::Object, &working_directory.join("module.o"))
|
||||
@ -1375,11 +1445,7 @@ impl Nac3 {
|
||||
)?;
|
||||
|
||||
Ok(PyBytes::new(py, &fs::read(filename).unwrap()).into())
|
||||
};
|
||||
|
||||
self.compile_method(obj, method_name, args, embedding_map, py, &link_fn)
|
||||
} else {
|
||||
let link_fn = |module: &Module| {
|
||||
let object_mem = target_machine
|
||||
.write_to_memory_buffer(module, FileType::Object)
|
||||
.expect("couldn't write module to object file buffer");
|
||||
@ -1388,12 +1454,12 @@ impl Nac3 {
|
||||
} else {
|
||||
Err(CompileError::new_err("linker failed to process object file"))
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
self.compile_method(obj, method_name, args, embedding_map, py, &link_fn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "init-llvm-profile")]
|
||||
unsafe extern "C" {
|
||||
|
@ -1,5 +1,6 @@
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
fmt::Debug,
|
||||
sync::{
|
||||
Arc,
|
||||
atomic::{AtomicBool, Ordering::Relaxed},
|
||||
@ -41,6 +42,7 @@ use nac3core::{
|
||||
|
||||
use super::PrimitivePythonId;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum PrimitiveValue {
|
||||
I32(i32),
|
||||
I64(i64),
|
||||
@ -73,10 +75,10 @@ impl DeferredEvaluationStore {
|
||||
|
||||
/// A class field as stored in the [`InnerResolver`], represented by the ID and name of the
|
||||
/// associated [`PythonValue`].
|
||||
type ResolverField = (u64, StrRef);
|
||||
pub(crate) type ResolverField = (u64, StrRef);
|
||||
|
||||
/// A value as stored in Python, represented by the `id()` and [`PyObject`] of the value.
|
||||
type PyValueHandle = (u64, Arc<PyObject>);
|
||||
pub(crate) type PyValueHandle = (u64, Arc<PyObject>);
|
||||
|
||||
pub struct InnerResolver {
|
||||
pub id_to_type: RwLock<HashMap<StrRef, Type>>,
|
||||
@ -97,6 +99,13 @@ pub struct InnerResolver {
|
||||
pub module: Arc<PyObject>,
|
||||
}
|
||||
|
||||
impl Debug for InnerResolver {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.debug_str(None, &None))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Resolver(pub Arc<InnerResolver>);
|
||||
|
||||
#[derive(Clone)]
|
||||
@ -197,19 +206,25 @@ impl StaticValue for PythonValue {
|
||||
.unwrap_or_else(|| {
|
||||
Python::with_gil(|py| -> PyResult<Option<PyValueHandle>> {
|
||||
let helper = &self.resolver.helper;
|
||||
let id = helper.id_fn.bind(py).call1((&*self.value,))?.extract::<u64>()?;
|
||||
let ty = helper.type_fn.bind(py).call1((&*self.value,))?;
|
||||
let ty_id: u64 = helper.id_fn.bind(py).call1((ty,))?.extract()?;
|
||||
|
||||
// for optimizing unwrap KernelInvariant
|
||||
if ty_id == self.resolver.primitive_ids.option && name == "_nac3_option".into() {
|
||||
let obj = Arc::new(self.value.getattr(py, name.to_string().as_str())?);
|
||||
let id = self.resolver.helper.id_fn.bind(py).call1((&*obj,))?.extract()?;
|
||||
let obj = self.value.bind(py).getattr(name.to_string().as_str())?;
|
||||
let id = self.resolver.helper.id_fn.bind(py).call1((&obj,))?.extract()?;
|
||||
let obj = Arc::new(obj.into_py_any(py)?);
|
||||
return if self.id == self.resolver.primitive_ids.none {
|
||||
Ok(None)
|
||||
} else {
|
||||
Ok(Some((id, obj)))
|
||||
};
|
||||
}
|
||||
let def_id = { *self.resolver.pyid_to_def.read().get(&ty_id).unwrap() };
|
||||
|
||||
let result = if let Some(def_id) =
|
||||
self.resolver.pyid_to_def.read().get(&ty_id).copied()
|
||||
{
|
||||
let mut mutable = true;
|
||||
let defs = ctx.top_level.definitions.read();
|
||||
if let TopLevelDef::Class { fields, .. } = &*defs[def_id.0].read() {
|
||||
@ -220,13 +235,36 @@ impl StaticValue for PythonValue {
|
||||
}
|
||||
}
|
||||
}
|
||||
let result = if mutable {
|
||||
|
||||
if mutable {
|
||||
None
|
||||
} else {
|
||||
let obj = Arc::new(self.value.getattr(py, name.to_string().as_str())?);
|
||||
let id = self.resolver.helper.id_fn.bind(py).call1((&*obj,))?.extract()?;
|
||||
let obj = self.value.bind(py).getattr(name.to_string().as_str())?;
|
||||
let id = self.resolver.helper.id_fn.bind(py).call1((&obj,))?.extract()?;
|
||||
let obj = Arc::new(obj.into_py_any(py)?);
|
||||
Some((id, obj))
|
||||
}
|
||||
} else if let Some(def_id) = self.resolver.pyid_to_def.read().get(&id).copied() {
|
||||
// Check if self.value is a module
|
||||
let in_mod_ctx = ctx
|
||||
.top_level
|
||||
.definitions
|
||||
.read()
|
||||
.get(def_id.0)
|
||||
.is_some_and(|def| matches!(&*def.read(), TopLevelDef::Module { .. }));
|
||||
|
||||
if in_mod_ctx {
|
||||
let obj = self.value.bind(py).getattr(name.to_string().as_str())?;
|
||||
let id = self.resolver.helper.id_fn.bind(py).call1((&obj,))?.extract()?;
|
||||
let obj = Arc::new(obj.into_py_any(py)?);
|
||||
Some((id, obj))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
self.resolver.field_to_val.write().insert((self.id, name), result.clone());
|
||||
Ok(result)
|
||||
})
|
||||
@ -1690,9 +1728,10 @@ impl SymbolResolver for Resolver {
|
||||
) -> Option<ValueEnum<'ctx>> {
|
||||
if let Some(def_id) = self.0.id_to_def.read().get(&id) {
|
||||
let top_levels = ctx.top_level.definitions.read();
|
||||
if matches!(&*top_levels[def_id.0].read(), TopLevelDef::Variable { .. }) {
|
||||
if let TopLevelDef::Variable { resolver, .. } = &*top_levels[def_id.0].read() {
|
||||
let module_val = &self.0.module;
|
||||
let ret = Python::with_gil(|py| -> PyResult<Result<BasicValueEnum, String>> {
|
||||
let Ok((obj, idx)) = Python::with_gil(
|
||||
|py| -> PyResult<Result<(BasicValueEnum<'ctx>, Option<usize>), String>> {
|
||||
let module_val = (**module_val).bind(py);
|
||||
|
||||
let ty = self.0.get_obj_type(
|
||||
@ -1706,8 +1745,31 @@ impl SymbolResolver for Resolver {
|
||||
return Ok(Err(ty));
|
||||
}
|
||||
let ty = ty.unwrap();
|
||||
let obj = self.0.get_obj_value(py, module_val, ctx, generator, ty)?.unwrap();
|
||||
let obj =
|
||||
self.0.get_obj_value(py, module_val, ctx, generator, ty)?.unwrap();
|
||||
let (idx, _) = ctx.get_attr_index(ty, id);
|
||||
|
||||
Ok(Ok((obj, idx)))
|
||||
},
|
||||
)
|
||||
.unwrap() else {
|
||||
return None;
|
||||
};
|
||||
|
||||
let Some(idx) = idx else {
|
||||
// `idx` not found in the current resolver - try the resolver of the variable
|
||||
return resolver.as_ref().and_then(|resolver| {
|
||||
let resolver = &**resolver;
|
||||
|
||||
// TODO: Can we assume that if get_identifier_def returns a result,
|
||||
// get_symbol_value will also return a value?
|
||||
resolver
|
||||
.get_identifier_def(id)
|
||||
.ok()
|
||||
.and_then(|_| resolver.get_symbol_value(id, ctx, generator))
|
||||
});
|
||||
};
|
||||
|
||||
let ret = unsafe {
|
||||
ctx.builder.build_gep(
|
||||
obj.into_pointer_value(),
|
||||
@ -1719,13 +1781,7 @@ impl SymbolResolver for Resolver {
|
||||
)
|
||||
}
|
||||
.unwrap();
|
||||
Ok(Ok(ret.as_basic_value_enum()))
|
||||
})
|
||||
.unwrap();
|
||||
if ret.is_err() {
|
||||
return None;
|
||||
}
|
||||
return Some(ret.unwrap().into());
|
||||
return Some(ret.as_basic_value_enum().into());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -124,7 +124,7 @@ impl<'ctx> CodeGenContext<'ctx, '_> {
|
||||
|
||||
/// Checks the field and attributes of classes
|
||||
/// Returns the index of attr in class fields otherwise returns the attribute value
|
||||
pub fn get_attr_index(&mut self, ty: Type, attr: StrRef) -> (usize, Option<Constant>) {
|
||||
pub fn get_attr_index(&mut self, ty: Type, attr: StrRef) -> (Option<usize>, Option<Constant>) {
|
||||
let obj_id = match &*self.unifier.get_ty(ty) {
|
||||
TypeEnum::TObj { obj_id, .. } => *obj_id,
|
||||
TypeEnum::TModule { module_id, .. } => *module_id,
|
||||
@ -134,13 +134,16 @@ impl<'ctx> CodeGenContext<'ctx, '_> {
|
||||
let def = &self.top_level.definitions.read()[obj_id.0];
|
||||
let (index, value) = if let TopLevelDef::Class { fields, attributes, .. } = &*def.read() {
|
||||
if let Some(field_index) = fields.iter().find_position(|x| x.0 == attr) {
|
||||
(field_index.0, None)
|
||||
(Some(field_index.0), None)
|
||||
} else {
|
||||
let attribute_index = attributes.iter().find_position(|x| x.0 == attr).unwrap();
|
||||
(attribute_index.0, Some(attribute_index.1.2.clone()))
|
||||
let attribute_index = attributes.iter().find_position(|x| x.0 == attr);
|
||||
(
|
||||
attribute_index.map(|(idx, _)| idx),
|
||||
attribute_index.map(|(_, (_, _, k))| k.clone()),
|
||||
)
|
||||
}
|
||||
} else if let TopLevelDef::Module { attributes, .. } = &*def.read() {
|
||||
(attributes.iter().find_position(|x| x.0 == attr).unwrap().0, None)
|
||||
(attributes.iter().find_position(|x| x.0 == attr).map(|(idx, _)| idx), None)
|
||||
} else {
|
||||
codegen_unreachable!(self)
|
||||
};
|
||||
@ -2461,7 +2464,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||
let (index, _) = ctx.get_attr_index(value.custom.unwrap(), *attr);
|
||||
Ok(ValueEnum::Dynamic(ctx.build_gep_and_load(
|
||||
v.into_pointer_value(),
|
||||
&[zero, int32.const_int(index as u64, false)],
|
||||
&[zero, int32.const_int(index.unwrap() as u64, false)],
|
||||
None,
|
||||
))) as Result<_, String>
|
||||
},
|
||||
@ -2478,7 +2481,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||
}
|
||||
ValueEnum::Dynamic(ctx.build_gep_and_load(
|
||||
v.into_pointer_value(),
|
||||
&[zero, int32.const_int(index as u64, false)],
|
||||
&[zero, int32.const_int(index.unwrap() as u64, false)],
|
||||
None,
|
||||
))
|
||||
}
|
||||
|
@ -133,7 +133,7 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
|
||||
ptr,
|
||||
&[
|
||||
ctx.ctx.i32_type().const_zero(),
|
||||
ctx.ctx.i32_type().const_int(index as u64, false),
|
||||
ctx.ctx.i32_type().const_int(index.unwrap() as u64, false),
|
||||
],
|
||||
name.unwrap_or(""),
|
||||
)
|
||||
|
@ -463,9 +463,9 @@ impl TopLevelComposer {
|
||||
|
||||
/// Registers a top-level variable with the given `name` into the composer.
|
||||
///
|
||||
/// `annotation` - The type annotation of the top-level variable, or [`None`] if no type
|
||||
/// - `annotation` - The type annotation of the top-level variable, or [`None`] if no type
|
||||
/// annotation is provided.
|
||||
/// `location` - The location of the top-level variable.
|
||||
/// - `location` - The location of the top-level variable.
|
||||
pub fn register_top_level_var(
|
||||
&mut self,
|
||||
name: Ident,
|
||||
@ -1999,13 +1999,15 @@ impl TopLevelComposer {
|
||||
ExprKind::Subscript { value, slice, .. }
|
||||
if matches!(
|
||||
&value.node,
|
||||
ast::ExprKind::Name { id, .. } if self.core_config.kernel_ann.is_some_and(|c| id == &c.into())
|
||||
ast::ExprKind::Name { id, .. } if self.core_config.kernel_ann.is_some_and(|c| id == &c.into()) || id == &self.core_config.kernel_invariant_ann.into()
|
||||
) =>
|
||||
{
|
||||
slice
|
||||
}
|
||||
_ if self.core_config.kernel_ann.is_none() => ty_decl,
|
||||
_ => unreachable!("Global variables should be annotated with Kernel[]"), // ignore fields annotated otherwise
|
||||
_ => unreachable!(
|
||||
"Global variables should be annotated with Kernel[] or KernelInvariant[]"
|
||||
), // ignore fields annotated otherwise
|
||||
};
|
||||
|
||||
let ty_annotation = parse_ast_to_type_annotation_kinds(
|
||||
|
@ -60,126 +60,50 @@ impl TypeAnnotation {
|
||||
}
|
||||
}
|
||||
|
||||
/// Parses an AST expression `expr` into a [`TypeAnnotation`].
|
||||
///
|
||||
/// * `locked` - A [`HashMap`] containing the IDs of known definitions, mapped to a [`Vec`] of all
|
||||
/// generic variables associated with the definition.
|
||||
/// * `type_var` - The type variable associated with the type argument currently being parsed. Pass
|
||||
/// [`None`] when this function is invoked externally.
|
||||
pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
|
||||
/// Converts a [`DefinitionId`] representing a [`TopLevelDef::Class`] and its type arguments into a
|
||||
/// [`TypeAnnotation`].
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn class_def_id_to_type_annotation<T, S: std::hash::BuildHasher + Clone>(
|
||||
resolver: &(dyn SymbolResolver + Send + Sync),
|
||||
top_level_defs: &[Arc<RwLock<TopLevelDef>>],
|
||||
unifier: &mut Unifier,
|
||||
primitives: &PrimitiveStore,
|
||||
expr: &ast::Expr<T>,
|
||||
// the key stores the type_var of this topleveldef::class, we only need this field here
|
||||
locked: HashMap<DefinitionId, Vec<Type>, S>,
|
||||
mut locked: HashMap<DefinitionId, Vec<Type>, S>,
|
||||
id: StrRef,
|
||||
(obj_id, type_args): (DefinitionId, Option<&Expr<T>>),
|
||||
location: &Location,
|
||||
) -> Result<TypeAnnotation, HashSet<String>> {
|
||||
let name_handle = |id: &StrRef,
|
||||
unifier: &mut Unifier,
|
||||
locked: HashMap<DefinitionId, Vec<Type>, S>| {
|
||||
if id == &"int32".into() {
|
||||
Ok(TypeAnnotation::Primitive(primitives.int32))
|
||||
} else if id == &"int64".into() {
|
||||
Ok(TypeAnnotation::Primitive(primitives.int64))
|
||||
} else if id == &"uint32".into() {
|
||||
Ok(TypeAnnotation::Primitive(primitives.uint32))
|
||||
} else if id == &"uint64".into() {
|
||||
Ok(TypeAnnotation::Primitive(primitives.uint64))
|
||||
} else if id == &"float".into() {
|
||||
Ok(TypeAnnotation::Primitive(primitives.float))
|
||||
} else if id == &"bool".into() {
|
||||
Ok(TypeAnnotation::Primitive(primitives.bool))
|
||||
} else if id == &"str".into() {
|
||||
Ok(TypeAnnotation::Primitive(primitives.str))
|
||||
} else if id == &"Exception".into() {
|
||||
Ok(TypeAnnotation::CustomClass { id: PrimDef::Exception.id(), params: Vec::default() })
|
||||
} else if let Ok(obj_id) = resolver.get_identifier_def(*id) {
|
||||
let type_vars = {
|
||||
let Some(top_level_def) = top_level_defs.get(obj_id.0) else {
|
||||
return Err(HashSet::from([format!(
|
||||
"NameError: name '{id}' is not defined (at {})",
|
||||
expr.location
|
||||
"NameError: name '{id}' is not defined (at {location})",
|
||||
)]));
|
||||
};
|
||||
let def_read = top_level_def.try_read();
|
||||
if let Some(def_read) = def_read {
|
||||
|
||||
// We need to use `try_read` here, since the composer may be processing our class right now,
|
||||
// which requires exclusive access to modify the class internals.
|
||||
//
|
||||
// `locked` is guaranteed to hold a k-v pair of the composer-processing class, so fallback
|
||||
// to it if the `top_level_def` is already locked for mutation.
|
||||
let type_vars = if let Some(def_read) = top_level_def.try_read() {
|
||||
if let TopLevelDef::Class { type_vars, .. } = &*def_read {
|
||||
type_vars.clone()
|
||||
} else {
|
||||
return Err(HashSet::from([format!(
|
||||
"function cannot be used as a type (at {})",
|
||||
expr.location
|
||||
"function cannot be used as a type (at {location})",
|
||||
)]));
|
||||
}
|
||||
} else {
|
||||
locked.get(&obj_id).unwrap().clone()
|
||||
}
|
||||
};
|
||||
// check param number here
|
||||
if !type_vars.is_empty() {
|
||||
return Err(HashSet::from([format!(
|
||||
"expect {} type variable parameter but got 0 (at {})",
|
||||
type_vars.len(),
|
||||
expr.location,
|
||||
)]));
|
||||
}
|
||||
Ok(TypeAnnotation::CustomClass { id: obj_id, params: vec![] })
|
||||
} else if let Ok(ty) = resolver.get_symbol_type(unifier, top_level_defs, primitives, *id) {
|
||||
if let TypeEnum::TVar { .. } = unifier.get_ty(ty).as_ref() {
|
||||
let var = unifier.get_fresh_var(Some(*id), Some(expr.location)).ty;
|
||||
unifier.unify(var, ty).unwrap();
|
||||
Ok(TypeAnnotation::TypeVar(ty))
|
||||
} else {
|
||||
Err(HashSet::from([format!(
|
||||
"`{}` is not a valid type annotation (at {})",
|
||||
id, expr.location
|
||||
)]))
|
||||
}
|
||||
} else {
|
||||
Err(HashSet::from([format!(
|
||||
"`{}` is not a valid type annotation (at {})",
|
||||
id, expr.location
|
||||
)]))
|
||||
}
|
||||
};
|
||||
|
||||
let class_name_handle =
|
||||
|id: &StrRef,
|
||||
slice: &ast::Expr<T>,
|
||||
unifier: &mut Unifier,
|
||||
mut locked: HashMap<DefinitionId, Vec<Type>, S>| {
|
||||
if ["virtual".into(), "Generic".into(), "tuple".into(), "Option".into()].contains(id) {
|
||||
return Err(HashSet::from([format!(
|
||||
"keywords cannot be class name (at {})",
|
||||
expr.location
|
||||
)]));
|
||||
}
|
||||
let obj_id = resolver.get_identifier_def(*id)?;
|
||||
let type_vars = {
|
||||
let Some(top_level_def) = top_level_defs.get(obj_id.0) else {
|
||||
return Err(HashSet::from([format!(
|
||||
"NameError: name '{id}' is not defined (at {})",
|
||||
expr.location
|
||||
)]));
|
||||
};
|
||||
let def_read = top_level_def.try_read();
|
||||
if let Some(def_read) = def_read {
|
||||
let TopLevelDef::Class { type_vars, .. } = &*def_read else {
|
||||
unreachable!("must be class here")
|
||||
};
|
||||
type_vars.clone()
|
||||
} else {
|
||||
locked.get(&obj_id).unwrap().clone()
|
||||
}
|
||||
};
|
||||
let param_type_infos = if let Some(slice) = type_args {
|
||||
// we do not check whether the application of type variables are compatible here
|
||||
let param_type_infos = {
|
||||
let params_ast = if let ast::ExprKind::Tuple { elts, .. } = &slice.node {
|
||||
elts.iter().collect_vec()
|
||||
} else {
|
||||
vec![slice]
|
||||
};
|
||||
|
||||
if type_vars.len() != params_ast.len() {
|
||||
return Err(HashSet::from([format!(
|
||||
"expect {} type parameters but got {} (at {})",
|
||||
@ -188,6 +112,7 @@ pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
|
||||
params_ast[0].location,
|
||||
)]));
|
||||
}
|
||||
|
||||
let result = params_ast
|
||||
.iter()
|
||||
.map(|x| {
|
||||
@ -204,6 +129,7 @@ pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
|
||||
)
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
// make sure the result do not contain any type vars
|
||||
let no_type_var =
|
||||
result.iter().all(|x| get_type_var_contained_in_type_annotation(x).is_empty());
|
||||
@ -215,12 +141,127 @@ pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
|
||||
params_ast[0].location
|
||||
)]));
|
||||
}
|
||||
};
|
||||
Ok(TypeAnnotation::CustomClass { id: obj_id, params: param_type_infos })
|
||||
} else {
|
||||
// check param number here
|
||||
if !type_vars.is_empty() {
|
||||
return Err(HashSet::from([format!(
|
||||
"expect {} type variable parameter but got 0 (at {location})",
|
||||
type_vars.len(),
|
||||
)]));
|
||||
}
|
||||
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
Ok(TypeAnnotation::CustomClass { id: obj_id, params: param_type_infos })
|
||||
}
|
||||
|
||||
/// Parses the `id` of a [`ast::ExprKind::Name`] expression as a [`TypeAnnotation`].
|
||||
fn parse_name_as_type_annotation<T, S: std::hash::BuildHasher + Clone>(
|
||||
resolver: &(dyn SymbolResolver + Send + Sync),
|
||||
top_level_defs: &[Arc<RwLock<TopLevelDef>>],
|
||||
unifier: &mut Unifier,
|
||||
primitives: &PrimitiveStore,
|
||||
locked: HashMap<DefinitionId, Vec<Type>, S>,
|
||||
id: StrRef,
|
||||
location: &Location,
|
||||
) -> Result<TypeAnnotation, HashSet<String>> {
|
||||
if id == "int32".into() {
|
||||
Ok(TypeAnnotation::Primitive(primitives.int32))
|
||||
} else if id == "int64".into() {
|
||||
Ok(TypeAnnotation::Primitive(primitives.int64))
|
||||
} else if id == "uint32".into() {
|
||||
Ok(TypeAnnotation::Primitive(primitives.uint32))
|
||||
} else if id == "uint64".into() {
|
||||
Ok(TypeAnnotation::Primitive(primitives.uint64))
|
||||
} else if id == "float".into() {
|
||||
Ok(TypeAnnotation::Primitive(primitives.float))
|
||||
} else if id == "bool".into() {
|
||||
Ok(TypeAnnotation::Primitive(primitives.bool))
|
||||
} else if id == "str".into() {
|
||||
Ok(TypeAnnotation::Primitive(primitives.str))
|
||||
} else if id == "Exception".into() {
|
||||
Ok(TypeAnnotation::CustomClass { id: PrimDef::Exception.id(), params: Vec::default() })
|
||||
} else if let Ok(obj_id) = resolver.get_identifier_def(id) {
|
||||
class_def_id_to_type_annotation(
|
||||
resolver,
|
||||
top_level_defs,
|
||||
unifier,
|
||||
primitives,
|
||||
locked,
|
||||
id,
|
||||
(obj_id, None as Option<&Expr<T>>),
|
||||
location,
|
||||
)
|
||||
} else if let Ok(ty) = resolver.get_symbol_type(unifier, top_level_defs, primitives, id) {
|
||||
if let TypeEnum::TVar { .. } = unifier.get_ty(ty).as_ref() {
|
||||
let var = unifier.get_fresh_var(Some(id), Some(*location)).ty;
|
||||
unifier.unify(var, ty).unwrap();
|
||||
Ok(TypeAnnotation::TypeVar(ty))
|
||||
} else {
|
||||
Err(HashSet::from([format!("`{id}` is not a valid type annotation (at {location})",)]))
|
||||
}
|
||||
} else {
|
||||
Err(HashSet::from([format!("`{id}` is not a valid type annotation (at {location})",)]))
|
||||
}
|
||||
}
|
||||
|
||||
/// Parses the `id` and generic arguments of a class as a [`TypeAnnotation`].
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn parse_class_id_as_type_annotation<T, S: std::hash::BuildHasher + Clone>(
|
||||
resolver: &(dyn SymbolResolver + Send + Sync),
|
||||
top_level_defs: &[Arc<RwLock<TopLevelDef>>],
|
||||
unifier: &mut Unifier,
|
||||
primitives: &PrimitiveStore,
|
||||
locked: HashMap<DefinitionId, Vec<Type>, S>,
|
||||
id: StrRef,
|
||||
slice: &Expr<T>,
|
||||
location: &Location,
|
||||
) -> Result<TypeAnnotation, HashSet<String>> {
|
||||
if ["virtual".into(), "Generic".into(), "tuple".into(), "Option".into()].contains(&id) {
|
||||
return Err(HashSet::from([format!("keywords cannot be class name (at {location})")]));
|
||||
}
|
||||
|
||||
let obj_id = resolver.get_identifier_def(id)?;
|
||||
|
||||
class_def_id_to_type_annotation(
|
||||
resolver,
|
||||
top_level_defs,
|
||||
unifier,
|
||||
primitives,
|
||||
locked,
|
||||
id,
|
||||
(obj_id, Some(slice)),
|
||||
location,
|
||||
)
|
||||
}
|
||||
|
||||
/// Parses an AST expression `expr` into a [`TypeAnnotation`].
|
||||
///
|
||||
/// * `locked` - A [`HashMap`] containing the IDs of known definitions, mapped to a [`Vec`] of all
|
||||
/// generic variables associated with the definition.
|
||||
/// * `type_var` - The type variable associated with the type argument currently being parsed. Pass
|
||||
/// [`None`] when this function is invoked externally.
|
||||
pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
|
||||
resolver: &(dyn SymbolResolver + Send + Sync),
|
||||
top_level_defs: &[Arc<RwLock<TopLevelDef>>],
|
||||
unifier: &mut Unifier,
|
||||
primitives: &PrimitiveStore,
|
||||
expr: &ast::Expr<T>,
|
||||
// the key stores the type_var of this topleveldef::class, we only need this field here
|
||||
locked: HashMap<DefinitionId, Vec<Type>, S>,
|
||||
) -> Result<TypeAnnotation, HashSet<String>> {
|
||||
match &expr.node {
|
||||
ast::ExprKind::Name { id, .. } => name_handle(id, unifier, locked),
|
||||
ast::ExprKind::Name { id, .. } => parse_name_as_type_annotation::<T, S>(
|
||||
resolver,
|
||||
top_level_defs,
|
||||
unifier,
|
||||
primitives,
|
||||
locked,
|
||||
*id,
|
||||
&expr.location,
|
||||
),
|
||||
|
||||
// virtual
|
||||
ast::ExprKind::Subscript { value, slice, .. }
|
||||
if {
|
||||
@ -341,9 +382,54 @@ pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
|
||||
|
||||
// custom class
|
||||
ast::ExprKind::Subscript { value, slice, .. } => {
|
||||
match &value.node {
|
||||
ast::ExprKind::Name { id, .. } => parse_class_id_as_type_annotation(
|
||||
resolver,
|
||||
top_level_defs,
|
||||
unifier,
|
||||
primitives,
|
||||
locked,
|
||||
*id,
|
||||
slice,
|
||||
&expr.location,
|
||||
),
|
||||
|
||||
ast::ExprKind::Attribute { value, attr, .. } => {
|
||||
if let ast::ExprKind::Name { id, .. } = &value.node {
|
||||
class_name_handle(id, slice, unifier, locked)
|
||||
let mod_id = resolver.get_identifier_def(*id)?;
|
||||
let Some(mod_tld) = top_level_defs.get(mod_id.0) else {
|
||||
return Err(HashSet::from([format!(
|
||||
"NameError: name '{id}' is not defined (at {})",
|
||||
expr.location
|
||||
)]));
|
||||
};
|
||||
|
||||
let matching_attr =
|
||||
if let TopLevelDef::Module { methods, .. } = &*mod_tld.read() {
|
||||
methods.get(attr).copied()
|
||||
} else {
|
||||
unreachable!("must be module here")
|
||||
};
|
||||
|
||||
let Some(def_id) = matching_attr else {
|
||||
return Err(HashSet::from([format!(
|
||||
"AttributeError: module '{id}' has no attribute '{attr}' (at {})",
|
||||
expr.location
|
||||
)]));
|
||||
};
|
||||
|
||||
class_def_id_to_type_annotation::<T, S>(
|
||||
resolver,
|
||||
top_level_defs,
|
||||
unifier,
|
||||
primitives,
|
||||
locked,
|
||||
*attr,
|
||||
(def_id, Some(slice)),
|
||||
&expr.location,
|
||||
)
|
||||
} else {
|
||||
// TODO: Handle multiple indirection
|
||||
Err(HashSet::from([format!(
|
||||
"unsupported expression type for class name (at {})",
|
||||
value.location
|
||||
@ -351,8 +437,57 @@ pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
|
||||
}
|
||||
}
|
||||
|
||||
_ => Err(HashSet::from([format!(
|
||||
"unsupported expression type for class name (at {})",
|
||||
value.location
|
||||
)])),
|
||||
}
|
||||
}
|
||||
|
||||
ast::ExprKind::Constant { value, .. } => Ok(TypeAnnotation::Literal(vec![value.clone()])),
|
||||
|
||||
ast::ExprKind::Attribute { value, attr, .. } => {
|
||||
if let ast::ExprKind::Name { id, .. } = &value.node {
|
||||
let mod_id = resolver.get_identifier_def(*id)?;
|
||||
let Some(mod_tld) = top_level_defs.get(mod_id.0) else {
|
||||
return Err(HashSet::from([format!(
|
||||
"NameError: name '{id}' is not defined (at {})",
|
||||
expr.location
|
||||
)]));
|
||||
};
|
||||
|
||||
let matching_attr = if let TopLevelDef::Module { methods, .. } = &*mod_tld.read() {
|
||||
methods.get(attr).copied()
|
||||
} else {
|
||||
unreachable!("must be module here")
|
||||
};
|
||||
|
||||
let Some(def_id) = matching_attr else {
|
||||
return Err(HashSet::from([format!(
|
||||
"AttributeError: module '{id}' has no attribute '{attr}' (at {})",
|
||||
expr.location
|
||||
)]));
|
||||
};
|
||||
|
||||
class_def_id_to_type_annotation::<T, S>(
|
||||
resolver,
|
||||
top_level_defs,
|
||||
unifier,
|
||||
primitives,
|
||||
locked,
|
||||
*attr,
|
||||
(def_id, None),
|
||||
&expr.location,
|
||||
)
|
||||
} else {
|
||||
// TODO: Handle multiple indirection
|
||||
Err(HashSet::from([format!(
|
||||
"unsupported expression type for class name (at {})",
|
||||
value.location
|
||||
)]))
|
||||
}
|
||||
}
|
||||
|
||||
_ => Err(HashSet::from([format!(
|
||||
"unsupported expression for type annotation (at {})",
|
||||
expr.location
|
||||
|
@ -19,8 +19,12 @@ use nac3core::{
|
||||
WithCall, WorkerRegistry, concrete_type::ConcreteTypeStore, irrt::load_irrt,
|
||||
},
|
||||
inkwell::{
|
||||
OptimizationLevel, memory_buffer::MemoryBuffer, module::Linkage,
|
||||
passes::PassBuilderOptions, support::is_multithreaded, targets::*,
|
||||
OptimizationLevel,
|
||||
memory_buffer::MemoryBuffer,
|
||||
module::{Linkage, Module},
|
||||
passes::PassBuilderOptions,
|
||||
support::is_multithreaded,
|
||||
targets::*,
|
||||
},
|
||||
nac3parser::{
|
||||
ast::{Constant, Expr, ExprKind, StmtKind, StrRef},
|
||||
@ -59,11 +63,13 @@ struct CommandLineArgs {
|
||||
#[arg(short = 'O', default_value_t = 2, value_parser = clap::value_parser!(u32).range(0..=3))]
|
||||
opt_level: u32,
|
||||
|
||||
/// Whether to emit LLVM IR at the end of every module.
|
||||
///
|
||||
/// If multithreaded compilation is also enabled, each thread will emit its own module.
|
||||
/// Whether to emit LLVM bitcode at the end of every module.
|
||||
#[arg(long, default_value_t = false)]
|
||||
emit_llvm: bool,
|
||||
emit_llvm_bc: bool,
|
||||
|
||||
/// Whether to emit LLVM IR text at the end of every module.
|
||||
#[arg(long, default_value_t = false)]
|
||||
emit_llvm_ir: bool,
|
||||
|
||||
/// The target triple to compile for.
|
||||
#[arg(long)]
|
||||
@ -276,8 +282,16 @@ fn handle_global_var(
|
||||
|
||||
fn main() {
|
||||
let cli = CommandLineArgs::parse();
|
||||
let CommandLineArgs { file_name, threads, opt_level, emit_llvm, triple, mcpu, target_features } =
|
||||
cli;
|
||||
let CommandLineArgs {
|
||||
file_name,
|
||||
threads,
|
||||
opt_level,
|
||||
emit_llvm_bc,
|
||||
emit_llvm_ir,
|
||||
triple,
|
||||
mcpu,
|
||||
target_features,
|
||||
} = cli;
|
||||
|
||||
Target::initialize_all(&InitializationConfig::default());
|
||||
|
||||
@ -346,11 +360,18 @@ fn main() {
|
||||
let resolver =
|
||||
Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
|
||||
|
||||
let emit_llvm = |module: &Module<'_>, filename: &str| {
|
||||
if emit_llvm_bc {
|
||||
module.write_bitcode_to_path(Path::new(format!("{filename}.bc").as_str()));
|
||||
}
|
||||
if emit_llvm_ir {
|
||||
module.print_to_file(Path::new(format!("{filename}.ll").as_str())).unwrap();
|
||||
}
|
||||
};
|
||||
|
||||
// Process IRRT
|
||||
let irrt = load_irrt(&context, resolver.as_ref());
|
||||
if emit_llvm {
|
||||
irrt.write_bitcode_to_path(Path::new("irrt.bc"));
|
||||
}
|
||||
emit_llvm(&irrt, "irrt");
|
||||
|
||||
// Process the Python script
|
||||
let parser_result = parser::parse_program(&program, file_name.into()).unwrap();
|
||||
@ -475,23 +496,19 @@ fn main() {
|
||||
let main = context
|
||||
.create_module_from_ir(MemoryBuffer::create_from_memory_range(&buffers[0], "main"))
|
||||
.unwrap();
|
||||
if emit_llvm {
|
||||
main.write_bitcode_to_path(Path::new("main.bc"));
|
||||
}
|
||||
emit_llvm(&main, "main");
|
||||
|
||||
for (idx, buffer) in buffers.iter().skip(1).enumerate() {
|
||||
for buffer in buffers.iter().skip(1) {
|
||||
let other = context
|
||||
.create_module_from_ir(MemoryBuffer::create_from_memory_range(buffer, "main"))
|
||||
.unwrap();
|
||||
|
||||
if emit_llvm {
|
||||
other.write_bitcode_to_path(Path::new(&format!("module{idx}.bc")));
|
||||
}
|
||||
|
||||
main.link_in_module(other).unwrap();
|
||||
}
|
||||
emit_llvm(&main, "main.merged");
|
||||
|
||||
main.link_in_module(irrt).unwrap();
|
||||
emit_llvm(&main, "main.fat");
|
||||
|
||||
// Private all functions except "run"
|
||||
let mut function_iter = main.get_first_function();
|
||||
@ -502,6 +519,8 @@ fn main() {
|
||||
function_iter = func.get_next_function();
|
||||
}
|
||||
|
||||
emit_llvm(&main, "main.pre-opt");
|
||||
|
||||
// Optimize `main`
|
||||
let pass_options = PassBuilderOptions::create();
|
||||
pass_options.set_merge_functions(true);
|
||||
@ -511,6 +530,8 @@ fn main() {
|
||||
panic!("Failed to run optimization for module `main`: {}", err.to_string());
|
||||
}
|
||||
|
||||
emit_llvm(&main, "main.post-opt");
|
||||
|
||||
// Write output
|
||||
target_machine
|
||||
.write_to_file(&main, FileType::Object, Path::new("module.o"))
|
||||
|
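The standalone compiler above gains --emit-llvm-bc and --emit-llvm-ir flags; for the embedded nac3artiq compiler the same dumps are keyed off the NAC3_EMIT_LLVM_BC and NAC3_EMIT_LLVM_LL environment variables (see ENV_NAC3_EMIT_LLVM_* in lib.rs). A minimal sketch, assuming the variables are read from the host process environment when compilation is triggered:

import os

# Hypothetical usage: request bitcode and textual IR dumps (main.bc, main.merged.ll, ...)
# before triggering kernel compilation from the host Python process.
os.environ["NAC3_EMIT_LLVM_BC"] = "1"
os.environ["NAC3_EMIT_LLVM_LL"] = "1"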