Add basic support for embedded functions with new compiler.

This commit is contained in:
whitequark 2015-08-07 11:44:49 +03:00
parent b6e2613f77
commit 353f454a29
11 changed files with 278 additions and 339 deletions

View File

@ -1 +1,2 @@
from .module import Module, Source
from .embedding import Stitcher

View File

@ -93,3 +93,5 @@ class YieldFromT(ast.YieldFrom, commontyped):
# Novel typed nodes
class CoerceT(ast.expr, commontyped):
_fields = ('value',) # other_value deliberately not in _fields
class QuoteT(ast.expr, commontyped):
_fields = ('value',)

157
artiq/compiler/embedding.py Normal file
View File

@ -0,0 +1,157 @@
"""
The :class:`Stitcher` class allows to transparently combine compiled
Python code and Python code executed on the host system: it resolves
the references to the host objects and translates the functions
annotated as ``@kernel`` when they are referenced.
"""
import inspect
from pythonparser import ast, source, diagnostic, parse_buffer
from . import types, builtins, asttyped, prelude
from .transforms import ASTTypedRewriter, Inferencer
class ASTSynthesizer:
def __init__(self):
self.source = ""
self.source_buffer = source.Buffer(self.source, "<synthesized>")
def finalize(self):
self.source_buffer.source = self.source
return self.source_buffer
def _add(self, fragment):
range_from = len(self.source)
self.source += fragment
range_to = len(self.source)
return source.Range(self.source_buffer, range_from, range_to)
def quote(self, value):
"""Construct an AST fragment equal to `value`."""
if value in (None, True, False):
if node.value is True or node.value is False:
typ = builtins.TBool()
elif node.value is None:
typ = builtins.TNone()
return asttyped.NameConstantT(value=value, type=typ,
loc=self._add(repr(value)))
elif isinstance(value, (int, float)):
if isinstance(value, int):
typ = builtins.TInt()
elif isinstance(value, float):
typ = builtins.TFloat()
return asttyped.NumT(n=value, ctx=None, type=typ,
loc=self._add(repr(value)))
elif isinstance(value, list):
begin_loc = self._add("[")
elts = []
for index, elt in value:
elts.append(self.quote(elt))
if index < len(value) - 1:
self._add(", ")
end_loc = self._add("]")
return asttyped.ListT(elts=elts, ctx=None, type=types.TVar(),
begin_loc=begin_loc, end_loc=end_loc,
loc=begin_loc.join(end_loc))
else:
raise "no"
# return asttyped.QuoteT(value=value, type=types.TVar())
def call(self, function_node, args, kwargs):
"""
Construct an AST fragment calling a function specified by
an AST node `function_node`, with given arguments.
"""
arg_nodes = []
kwarg_nodes = []
kwarg_locs = []
name_loc = self._add(function_node.name)
begin_loc = self._add("(")
for index, arg in enumerate(args):
arg_nodes.append(self.quote(arg))
if index < len(args) - 1:
self._add(", ")
if any(args) and any(kwargs):
self._add(", ")
for index, kw in enumerate(kwargs):
arg_loc = self._add(kw)
equals_loc = self._add("=")
kwarg_locs.append((arg_loc, equals_loc))
kwarg_nodes.append(self.quote(kwargs[kw]))
if index < len(kwargs) - 1:
self._add(", ")
end_loc = self._add(")")
return asttyped.CallT(
func=asttyped.NameT(id=function_node.name, ctx=None,
type=function_node.signature_type,
loc=name_loc),
args=arg_nodes,
keywords=[ast.keyword(arg=kw, value=value,
arg_loc=arg_loc, equals_loc=equals_loc,
loc=arg_loc.join(value.loc))
for kw, value, (arg_loc, equals_loc)
in zip(kwargs, kwarg_nodes, kwarg_locs)],
starargs=None, kwargs=None,
type=types.TVar(),
begin_loc=begin_loc, end_loc=end_loc, star_loc=None, dstar_loc=None,
loc=name_loc.join(end_loc))
class StitchingASTTypedRewriter(ASTTypedRewriter):
pass
class Stitcher:
def __init__(self, engine=None):
if engine is None:
self.engine = diagnostic.Engine(all_errors_are_fatal=True)
else:
self.engine = engine
self.asttyped_rewriter = StitchingASTTypedRewriter(
engine=self.engine, globals=prelude.globals())
self.inferencer = Inferencer(engine=self.engine)
self.name = "stitched"
self.typedtree = None
self.globals = self.asttyped_rewriter.globals
self.rpc_map = {}
def _iterate(self):
# Iterate inference to fixed point.
self.inference_finished = False
while not self.inference_finished:
self.inference_finished = True
self.inferencer.visit(self.typedtree)
def _parse_embedded_function(self, function):
if not hasattr(function, "artiq_embedded"):
raise ValueError("{} is not an embedded function".format(repr(function)))
# Extract function source.
embedded_function = function.artiq_embedded.function
source_code = inspect.getsource(embedded_function)
filename = embedded_function.__code__.co_filename
first_line = embedded_function.__code__.co_firstlineno
# Parse.
source_buffer = source.Buffer(source_code, filename, first_line)
parsetree, comments = parse_buffer(source_buffer, engine=self.engine)
# Rewrite into typed form.
typedtree = self.asttyped_rewriter.visit(parsetree)
return typedtree, typedtree.body[0]
def stitch_call(self, function, args, kwargs):
self.typedtree, function_node = self._parse_embedded_function(function)
# We synthesize fake source code for the initial call so that
# diagnostics would have something meaningful to display to the user.
synthesizer = ASTSynthesizer()
call_node = synthesizer.call(function_node, args, kwargs)
synthesizer.finalize()
self.typedtree.body.append(call_node)
self._iterate()

View File

@ -190,10 +190,15 @@ class ASTTypedRewriter(algorithm.Transformer):
self.globals = None
self.env_stack = [globals]
def _find_name(self, name, loc):
def _try_find_name(self, name):
for typing_env in reversed(self.env_stack):
if name in typing_env:
return typing_env[name]
def _find_name(self, name, loc):
typ = self._try_find_name(name)
if typ is not None:
return typ
diag = diagnostic.Diagnostic("fatal",
"name '{name}' is not bound to anything", {"name":name}, loc)
self.engine.process(diag)

View File

@ -3,9 +3,7 @@ import logging
from enum import Enum
from fractions import Fraction
from artiq.coredevice import runtime_exceptions
from artiq.language import core as core_language
from artiq.coredevice.rpc_wrapper import RPCWrapper
logger = logging.getLogger(__name__)
@ -198,35 +196,28 @@ class CommGeneric:
else:
r.append(self._receive_rpc_value(type_tag))
def _serve_rpc(self, rpc_wrapper, rpc_map, user_exception_map):
def _serve_rpc(self, rpc_map):
rpc_num = struct.unpack(">l", self.read(4))[0]
args = self._receive_rpc_values()
logger.debug("rpc service: %d %r", rpc_num, args)
eid, r = rpc_wrapper.run_rpc(
user_exception_map, rpc_map[rpc_num], args)
eid, r = rpc_wrapper.run_rpc(rpc_map[rpc_num], args)
self._write_header(9+2*4, _H2DMsgType.RPC_REPLY)
self.write(struct.pack(">ll", eid, r))
logger.debug("rpc service: %d %r == %r (eid %d)", rpc_num, args,
r, eid)
def _serve_exception(self, rpc_wrapper, user_exception_map):
def _serve_exception(self):
eid, p0, p1, p2 = struct.unpack(">lqqq", self.read(4+3*8))
rpc_wrapper.filter_rpc_exception(eid)
if eid < core_language.first_user_eid:
exception = runtime_exceptions.exception_map[eid]
raise exception(self.core, p0, p1, p2)
else:
exception = user_exception_map[eid]
raise exception
def serve(self, rpc_map, user_exception_map):
rpc_wrapper = RPCWrapper()
def serve(self, rpc_map):
while True:
_, ty = self._read_header()
if ty == _D2HMsgType.RPC_REQUEST:
self._serve_rpc(rpc_wrapper, rpc_map, user_exception_map)
self._serve_rpc(rpc_map)
elif ty == _D2HMsgType.KERNEL_EXCEPTION:
self._serve_exception(rpc_wrapper, user_exception_map)
self._serve_exception()
elif ty == _D2HMsgType.KERNEL_FINISHED:
return
else:

View File

@ -1,50 +1,20 @@
import os
import os, sys, tempfile
from pythonparser import diagnostic
from artiq.language.core import *
from artiq.language.units import ns
from artiq.transforms.inline import inline
from artiq.transforms.quantize_time import quantize_time
from artiq.transforms.remove_inter_assigns import remove_inter_assigns
from artiq.transforms.fold_constants import fold_constants
from artiq.transforms.remove_dead_code import remove_dead_code
from artiq.transforms.unroll_loops import unroll_loops
from artiq.transforms.interleave import interleave
from artiq.transforms.lower_time import lower_time
from artiq.transforms.unparse import unparse
from artiq.compiler import Stitcher, Module
from artiq.compiler.targets import OR1KTarget
from artiq.coredevice.runtime import Runtime
from artiq.py2llvm import get_runtime_binary
# Import for side effects (creating the exception classes).
from artiq.coredevice import exceptions
def _announce_unparse(label, node):
print("*** Unparsing: "+label)
print(unparse(node))
def _make_debug_unparse(final):
try:
env = os.environ["ARTIQ_UNPARSE"]
except KeyError:
env = ""
selected_labels = set(env.split())
if "all" in selected_labels:
return _announce_unparse
else:
if "final" in selected_labels:
selected_labels.add(final)
def _filtered_unparse(label, node):
if label in selected_labels:
_announce_unparse(label, node)
return _filtered_unparse
def _no_debug_unparse(label, node):
class CompileError(Exception):
pass
class Core:
def __init__(self, dmgr, ref_period=8*ns, external_clock=False):
self.comm = dmgr.get("comm")
@ -54,70 +24,41 @@ class Core:
self.first_run = True
self.core = self
self.comm.core = self
self.runtime = Runtime()
def transform_stack(self, func_def, rpc_map, exception_map,
debug_unparse=_no_debug_unparse):
remove_inter_assigns(func_def)
debug_unparse("remove_inter_assigns_1", func_def)
def compile(self, function, args, kwargs, with_attr_writeback=True):
try:
engine = diagnostic.Engine(all_errors_are_fatal=True)
quantize_time(func_def, self.ref_period)
debug_unparse("quantize_time", func_def)
stitcher = Stitcher(engine=engine)
stitcher.stitch_call(function, args, kwargs)
fold_constants(func_def)
debug_unparse("fold_constants_1", func_def)
module = Module(stitcher)
library = OR1KTarget().compile_and_link([module])
unroll_loops(func_def, 500)
debug_unparse("unroll_loops", func_def)
return library, stitcher.rpc_map
except diagnostic.Error as error:
print("\n".join(error.diagnostic.render(colored=True)), file=sys.stderr)
raise CompileError() from error
interleave(func_def)
debug_unparse("interleave", func_def)
lower_time(func_def)
debug_unparse("lower_time", func_def)
remove_inter_assigns(func_def)
debug_unparse("remove_inter_assigns_2", func_def)
fold_constants(func_def)
debug_unparse("fold_constants_2", func_def)
remove_dead_code(func_def)
debug_unparse("remove_dead_code_1", func_def)
remove_inter_assigns(func_def)
debug_unparse("remove_inter_assigns_3", func_def)
fold_constants(func_def)
debug_unparse("fold_constants_3", func_def)
remove_dead_code(func_def)
debug_unparse("remove_dead_code_2", func_def)
def compile(self, k_function, k_args, k_kwargs, with_attr_writeback=True):
debug_unparse = _make_debug_unparse("remove_dead_code_2")
func_def, rpc_map, exception_map = inline(
self, k_function, k_args, k_kwargs, with_attr_writeback)
debug_unparse("inline", func_def)
self.transform_stack(func_def, rpc_map, exception_map, debug_unparse)
binary = get_runtime_binary(self.runtime, func_def)
return binary, rpc_map, exception_map
def run(self, k_function, k_args, k_kwargs):
def run(self, function, args, kwargs):
if self.first_run:
self.comm.check_ident()
self.comm.switch_clock(self.external_clock)
binary, rpc_map, exception_map = self.compile(
k_function, k_args, k_kwargs)
self.comm.load(binary)
self.comm.run(k_function.__name__)
self.comm.serve(rpc_map, exception_map)
self.first_run = False
kernel_library, rpc_map = self.compile(function, args, kwargs)
try:
self.comm.load(kernel_library)
except Exception as error:
shlib_temp = tempfile.NamedTemporaryFile(suffix=".so", delete=False)
shlib_temp.write(kernel_library)
shlib_temp.close()
raise RuntimeError("shared library dumped to {}".format(shlib_temp.name)) from error
self.comm.run()
self.comm.serve(rpc_map)
@kernel
def get_rtio_counter_mu(self):
return syscall("rtio_get_counter")

View File

@ -0,0 +1,33 @@
from artiq.language.core import ARTIQException
class InternalError(ARTIQException):
"""Raised when the runtime encounters an internal error condition."""
class RTIOUnderflow(ARTIQException):
"""Raised when the CPU fails to submit a RTIO event early enough
(with respect to the event's timestamp).
The offending event is discarded and the RTIO core keeps operating.
"""
class RTIOSequenceError(ARTIQException):
"""Raised when an event is submitted on a given channel with a timestamp
not larger than the previous one.
The offending event is discarded and the RTIO core keeps operating.
"""
class RTIOOverflow(ARTIQException):
"""Raised when at least one event could not be registered into the RTIO
input FIFO because it was full (CPU not reading fast enough).
This does not interrupt operations further than cancelling the current
read attempt and discarding some events. Reading can be reattempted after
the exception is caught, and events will be partially retrieved.
"""
class DDSBatchError(ARTIQException):
"""Raised when attempting to start a DDS batch while already in a batch,
or when too many commands are batched.
"""

View File

@ -1,40 +0,0 @@
from artiq.coredevice.runtime_exceptions import exception_map, _RPCException
def _lookup_exception(d, e):
for eid, exception in d.items():
if isinstance(e, exception):
return eid
return 0
class RPCWrapper:
def __init__(self):
self.last_exception = None
def run_rpc(self, user_exception_map, fn, args):
eid = 0
r = None
try:
r = fn(*args)
except Exception as e:
eid = _lookup_exception(user_exception_map, e)
if not eid:
eid = _lookup_exception(exception_map, e)
if eid:
self.last_exception = None
else:
self.last_exception = e
eid = _RPCException.eid
if r is None:
r = 0
else:
r = int(r)
return eid, r
def filter_rpc_exception(self, eid):
if eid == _RPCException.eid:
raise self.last_exception

View File

@ -3,14 +3,8 @@ import os
import llvmlite_or1k.ir as ll
import llvmlite_or1k.binding as llvm
from artiq.py2llvm import base_types, fractions, lists
from artiq.language import units
llvm.initialize()
llvm.initialize_all_targets()
llvm.initialize_all_asmprinters()
_syscalls = {
"now_init": "n:I",
"now_save": "I:n",
@ -97,27 +91,6 @@ class LinkInterface:
self.syscalls[func_name] = ll.Function(
llvm_module, func_type, "__syscall_" + func_name)
# exception handling
func_type = ll.FunctionType(ll.IntType(32),
[ll.PointerType(ll.IntType(8))])
self.eh_setjmp = ll.Function(llvm_module, func_type,
"__eh_setjmp")
self.eh_setjmp.attributes.add("nounwind")
self.eh_setjmp.attributes.add("returns_twice")
func_type = ll.FunctionType(ll.PointerType(ll.IntType(8)), [])
self.eh_push = ll.Function(llvm_module, func_type, "__eh_push")
func_type = ll.FunctionType(ll.VoidType(), [ll.IntType(32)])
self.eh_pop = ll.Function(llvm_module, func_type, "__eh_pop")
func_type = ll.FunctionType(ll.IntType(32), [])
self.eh_getid = ll.Function(llvm_module, func_type, "__eh_getid")
func_type = ll.FunctionType(ll.VoidType(), [ll.IntType(32)])
self.eh_raise = ll.Function(llvm_module, func_type, "__eh_raise")
self.eh_raise.attributes.add("noreturn")
def _build_rpc(self, args, builder):
r = base_types.VInt()
if builder is not None:
@ -159,54 +132,3 @@ class LinkInterface:
return self._build_rpc(args, builder)
else:
return self._build_regular_syscall(syscall_name, args, builder)
def build_catch(self, builder):
jmpbuf = builder.call(self.eh_push, [])
exception_occured = builder.call(self.eh_setjmp, [jmpbuf])
return builder.icmp_signed("!=",
exception_occured,
ll.Constant(ll.IntType(32), 0))
def build_pop(self, builder, levels):
builder.call(self.eh_pop, [ll.Constant(ll.IntType(32), levels)])
def build_getid(self, builder):
return builder.call(self.eh_getid, [])
def build_raise(self, builder, eid):
builder.call(self.eh_raise, [eid])
def _debug_dump_obj(obj):
try:
env = os.environ["ARTIQ_DUMP_OBJECT"]
except KeyError:
return
for i in range(1000):
filename = "{}_{:03d}.elf".format(env, i)
try:
f = open(filename, "xb")
except FileExistsError:
pass
else:
f.write(obj)
f.close()
return
raise IOError
class Runtime(LinkInterface):
def __init__(self):
self.cpu_type = "or1k"
# allow 1ms for all initial DDS programming
self.warmup_time = 1*units.ms
def emit_object(self):
tm = llvm.Target.from_triple(self.cpu_type).create_target_machine()
obj = tm.emit_object(self.module.llvm_module_ref)
_debug_dump_obj(obj)
return obj
def __repr__(self):
return "<Runtime {}>".format(self.cpu_type)

View File

@ -1,69 +0,0 @@
import inspect
from artiq.language.core import RuntimeException
# Must be kept in sync with soc/runtime/exceptions.h
class InternalError(RuntimeException):
"""Raised when the runtime encounters an internal error condition."""
eid = 1
class _RPCException(RuntimeException):
eid = 2
class RTIOUnderflow(RuntimeException):
"""Raised when the CPU fails to submit a RTIO event early enough
(with respect to the event's timestamp).
The offending event is discarded and the RTIO core keeps operating.
"""
eid = 3
def __str__(self):
return "at {} on channel {}, violation {}".format(
self.p0*self.core.ref_period,
self.p1,
(self.p2 - self.p0)*self.core.ref_period)
class RTIOSequenceError(RuntimeException):
"""Raised when an event is submitted on a given channel with a timestamp
not larger than the previous one.
The offending event is discarded and the RTIO core keeps operating.
"""
eid = 4
def __str__(self):
return "at {} on channel {}".format(self.p0*self.core.ref_period,
self.p1)
class RTIOOverflow(RuntimeException):
"""Raised when at least one event could not be registered into the RTIO
input FIFO because it was full (CPU not reading fast enough).
This does not interrupt operations further than cancelling the current
read attempt and discarding some events. Reading can be reattempted after
the exception is caught, and events will be partially retrieved.
"""
eid = 5
def __str__(self):
return "on channel {}".format(self.p0)
class DDSBatchError(RuntimeException):
"""Raised when attempting to start a DDS batch while already in a batch,
or when too many commands are batched.
"""
eid = 6
exception_map = {e.eid: e for e in globals().values()
if inspect.isclass(e)
and issubclass(e, RuntimeException)
and hasattr(e, "eid")}

View File

@ -8,7 +8,7 @@ from functools import wraps
__all__ = ["int64", "round64", "kernel", "portable",
"set_time_manager", "set_syscall_manager", "set_watchdog_factory",
"RuntimeException", "EncodedException"]
"ARTIQException"]
# global namespace for kernels
kernel_globals = ("sequential", "parallel",
@ -77,7 +77,7 @@ def round64(x):
return int64(round(x))
_KernelFunctionInfo = namedtuple("_KernelFunctionInfo", "core_name k_function")
_ARTIQEmbeddedInfo = namedtuple("_ARTIQEmbeddedInfo", "core_name function")
def kernel(arg):
@ -100,25 +100,19 @@ def kernel(arg):
specifies the name of the attribute to use as core device driver.
"""
if isinstance(arg, str):
def real_decorator(k_function):
@wraps(k_function)
def run_on_core(exp, *k_args, **k_kwargs):
return getattr(exp, arg).run(k_function,
((exp,) + k_args), k_kwargs)
run_on_core.k_function_info = _KernelFunctionInfo(
core_name=arg, k_function=k_function)
def inner_decorator(function):
@wraps(function)
def run_on_core(self, *k_args, **k_kwargs):
return getattr(self, arg).run(function, ((self,) + k_args), k_kwargs)
run_on_core.artiq_embedded = _ARTIQEmbeddedInfo(
core_name=arg, function=function)
return run_on_core
return real_decorator
return inner_decorator
else:
@wraps(arg)
def run_on_core(exp, *k_args, **k_kwargs):
return exp.core.run(arg, ((exp,) + k_args), k_kwargs)
run_on_core.k_function_info = _KernelFunctionInfo(
core_name="core", k_function=arg)
return run_on_core
return kernel("core")(arg)
def portable(f):
def portable(function):
"""This decorator marks a function for execution on the same device as its
caller.
@ -127,8 +121,8 @@ def portable(f):
core device). A decorated function called from a kernel will be executed
on the core device (no RPC).
"""
f.k_function_info = _KernelFunctionInfo(core_name="", k_function=f)
return f
function.artiq_embedded = _ARTIQEmbeddedInfo(core_name="", function=function)
return function
class _DummyTimeManager:
@ -280,32 +274,34 @@ def watchdog(timeout):
return _watchdog_factory(timeout)
_encoded_exceptions = dict()
class ARTIQException(Exception):
"""Base class for exceptions raised or passed through the core device."""
# Try and create an instance of the specific class, if one exists.
def __new__(cls, name, message, params):
def find_subclass(cls):
if cls.__name__ == name:
return cls
else:
for subclass in cls.__subclasses__():
cls = find_subclass(subclass)
if cls is not None:
return cls
def EncodedException(eid):
"""Represents exceptions on the core device, which are identified
by a single number."""
try:
return _encoded_exceptions[eid]
except KeyError:
class EncodedException(Exception):
def __init__(self):
Exception.__init__(self, eid)
_encoded_exceptions[eid] = EncodedException
return EncodedException
more_specific_cls = find_subclass(cls)
if more_specific_cls is None:
more_specific_cls = cls
exn = Exception.__new__(more_specific_cls)
exn.__init__(name, message, params)
return exn
class RuntimeException(Exception):
"""Base class for all exceptions used by the device runtime.
Those exceptions are defined in ``artiq.coredevice.runtime_exceptions``.
"""
def __init__(self, core, p0, p1, p2):
Exception.__init__(self)
self.core = core
self.p0 = p0
self.p1 = p1
self.p2 = p2
def __init__(self, name, message, params):
Exception.__init__(self, name, message, *params)
self.name, self.message, self.params = name, message, params
first_user_eid = 1024
def __str__(self):
if type(self).__name__ == self.name:
return self.message.format(*self.params)
else:
return "({}) {}".format(self.name, self.message.format(*self.params))