forked from M-Labs/artiq
1
0
Fork 0
This commit is contained in:
Sebastien Bourdeauducq 2014-09-05 12:03:22 +08:00
parent 7e9df82e37
commit 4915b4b5aa
43 changed files with 3179 additions and 2915 deletions

View File

@ -1,81 +1,86 @@
import ast, operator import ast
import operator
from artiq.compiler.tools import * from artiq.compiler.tools import *
from artiq.language.core import int64, round64 from artiq.language.core import int64, round64
_ast_unops = { _ast_unops = {
ast.Invert: operator.inv, ast.Invert: operator.inv,
ast.Not: operator.not_, ast.Not: operator.not_,
ast.UAdd: operator.pos, ast.UAdd: operator.pos,
ast.USub: operator.neg ast.USub: operator.neg
} }
_ast_binops = { _ast_binops = {
ast.Add: operator.add, ast.Add: operator.add,
ast.Sub: operator.sub, ast.Sub: operator.sub,
ast.Mult: operator.mul, ast.Mult: operator.mul,
ast.Div: operator.truediv, ast.Div: operator.truediv,
ast.FloorDiv: operator.floordiv, ast.FloorDiv: operator.floordiv,
ast.Mod: operator.mod, ast.Mod: operator.mod,
ast.Pow: operator.pow, ast.Pow: operator.pow,
ast.LShift: operator.lshift, ast.LShift: operator.lshift,
ast.RShift: operator.rshift, ast.RShift: operator.rshift,
ast.BitOr: operator.or_, ast.BitOr: operator.or_,
ast.BitXor: operator.xor, ast.BitXor: operator.xor,
ast.BitAnd: operator.and_ ast.BitAnd: operator.and_
} }
class _ConstantFolder(ast.NodeTransformer): class _ConstantFolder(ast.NodeTransformer):
def visit_UnaryOp(self, node): def visit_UnaryOp(self, node):
self.generic_visit(node) self.generic_visit(node)
try: try:
operand = eval_constant(node.operand) operand = eval_constant(node.operand)
except NotConstant: except NotConstant:
return node return node
try: try:
op = _ast_unops[type(node.op)] op = _ast_unops[type(node.op)]
except KeyError: except KeyError:
return node return node
try: try:
result = value_to_ast(op(operand)) result = value_to_ast(op(operand))
except: except:
return node return node
return ast.copy_location(result, node) return ast.copy_location(result, node)
def visit_BinOp(self, node): def visit_BinOp(self, node):
self.generic_visit(node) self.generic_visit(node)
try: try:
left, right = eval_constant(node.left), eval_constant(node.right) left, right = eval_constant(node.left), eval_constant(node.right)
except NotConstant: except NotConstant:
return node return node
try: try:
op = _ast_binops[type(node.op)] op = _ast_binops[type(node.op)]
except KeyError: except KeyError:
return node return node
try: try:
result = value_to_ast(op(left, right)) result = value_to_ast(op(left, right))
except: except:
return node return node
return ast.copy_location(result, node) return ast.copy_location(result, node)
def visit_Call(self, node):
self.generic_visit(node)
fn = node.func.id
constant_ops = {
"int": int,
"int64": int64,
"round": round,
"round64": round64
}
if fn in constant_ops:
try:
arg = eval_constant(node.args[0])
except NotConstant:
return node
result = value_to_ast(constant_ops[fn](arg))
return ast.copy_location(result, node)
else:
return node
def visit_Call(self, node):
self.generic_visit(node)
fn = node.func.id
constant_ops = {
"int": int,
"int64": int64,
"round": round,
"round64": round64
}
if fn in constant_ops:
try:
arg = eval_constant(node.args[0])
except NotConstant:
return node
result = value_to_ast(constant_ops[fn](arg))
return ast.copy_location(result, node)
else:
return node
def fold_constants(node): def fold_constants(node):
_ConstantFolder().visit(node) _ConstantFolder().visit(node)

View File

@ -1,229 +1,253 @@
from collections import namedtuple, defaultdict from collections import namedtuple, defaultdict
from fractions import Fraction from fractions import Fraction
import inspect, textwrap, ast import inspect
import textwrap
import ast
from artiq.compiler.tools import eval_ast, value_to_ast from artiq.compiler.tools import eval_ast, value_to_ast
from artiq.language import core as core_language from artiq.language import core as core_language
from artiq.language import units from artiq.language import units
_UserVariable = namedtuple("_UserVariable", "name") _UserVariable = namedtuple("_UserVariable", "name")
def _is_in_attr_list(obj, attr, al): def _is_in_attr_list(obj, attr, al):
if not hasattr(obj, al): if not hasattr(obj, al):
return False return False
return attr in getattr(obj, al).split() return attr in getattr(obj, al).split()
class _ReferenceManager: class _ReferenceManager:
def __init__(self): def __init__(self):
# (id(obj), funcname, local) -> _UserVariable(name) / ast / constant_object # (id(obj), funcname, local)
# local is None for kernel attributes # -> _UserVariable(name) / ast / constant_object
self.to_inlined = dict() # local is None for kernel attributes
# inlined_name -> use_count self.to_inlined = dict()
self.use_count = dict() # inlined_name -> use_count
self.rpc_map = defaultdict(lambda: len(self.rpc_map)) self.use_count = dict()
self.kernel_attr_init = [] self.rpc_map = defaultdict(lambda: len(self.rpc_map))
self.kernel_attr_init = []
# reserved names # reserved names
for kg in core_language.kernel_globals: for kg in core_language.kernel_globals:
self.use_count[kg] = 1 self.use_count[kg] = 1
for name in "int", "round", "int64", "round64", \ for name in ("int", "round", "int64", "round64",
"range", "Fraction", "Quantity", \ "range", "Fraction", "Quantity",
"s_unit", "Hz_unit", "microcycle_unit": "s_unit", "Hz_unit", "microcycle_unit"):
self.use_count[name] = 1 self.use_count[name] = 1
def new_name(self, base_name): def new_name(self, base_name):
if base_name[-1].isdigit(): if base_name[-1].isdigit():
base_name += "_" base_name += "_"
if base_name in self.use_count: if base_name in self.use_count:
r = base_name + str(self.use_count[base_name]) r = base_name + str(self.use_count[base_name])
self.use_count[base_name] += 1 self.use_count[base_name] += 1
return r return r
else: else:
self.use_count[base_name] = 1 self.use_count[base_name] = 1
return base_name return base_name
def get(self, obj, funcname, ref): def get(self, obj, funcname, ref):
store = isinstance(ref.ctx, ast.Store) store = isinstance(ref.ctx, ast.Store)
if isinstance(ref, ast.Name): if isinstance(ref, ast.Name):
key = (id(obj), funcname, ref.id) key = (id(obj), funcname, ref.id)
try: try:
return self.to_inlined[key] return self.to_inlined[key]
except KeyError: except KeyError:
if store: if store:
ival = _UserVariable(self.new_name(ref.id)) ival = _UserVariable(self.new_name(ref.id))
self.to_inlined[key] = ival self.to_inlined[key] = ival
return ival return ival
if isinstance(ref, ast.Attribute) and isinstance(ref.value, ast.Name): if isinstance(ref, ast.Attribute) and isinstance(ref.value, ast.Name):
try: try:
value = self.to_inlined[(id(obj), funcname, ref.value.id)] value = self.to_inlined[(id(obj), funcname, ref.value.id)]
except KeyError: except KeyError:
pass pass
else: else:
if _is_in_attr_list(value, ref.attr, "kernel_attr_ro"): if _is_in_attr_list(value, ref.attr, "kernel_attr_ro"):
if store: if store:
raise TypeError("Attempted to assign to read-only kernel attribute") raise TypeError(
return getattr(value, ref.attr) "Attempted to assign to read-only"
if _is_in_attr_list(value, ref.attr, "kernel_attr"): " kernel attribute")
key = (id(value), ref.attr, None) return getattr(value, ref.attr)
try: if _is_in_attr_list(value, ref.attr, "kernel_attr"):
ival = self.to_inlined[key] key = (id(value), ref.attr, None)
assert(isinstance(ival, _UserVariable)) try:
except KeyError: ival = self.to_inlined[key]
iname = self.new_name(ref.attr) assert(isinstance(ival, _UserVariable))
ival = _UserVariable(iname) except KeyError:
self.to_inlined[key] = ival iname = self.new_name(ref.attr)
a = value_to_ast(getattr(value, ref.attr)) ival = _UserVariable(iname)
if a is None: self.to_inlined[key] = ival
raise NotImplementedError("Cannot represent initial value of kernel attribute") a = value_to_ast(getattr(value, ref.attr))
self.kernel_attr_init.append(ast.Assign( if a is None:
[ast.Name(iname, ast.Store())], a)) raise NotImplementedError(
return ival "Cannot represent initial value"
" of kernel attribute")
self.kernel_attr_init.append(ast.Assign(
[ast.Name(iname, ast.Store())], a))
return ival
if not store: if not store:
evd = self.get_constants(obj, funcname) evd = self.get_constants(obj, funcname)
evd.update(inspect.getmodule(obj).__dict__) evd.update(inspect.getmodule(obj).__dict__)
return eval_ast(ref, evd) return eval_ast(ref, evd)
else: else:
raise KeyError raise KeyError
def set(self, obj, funcname, name, value): def set(self, obj, funcname, name, value):
self.to_inlined[(id(obj), funcname, name)] = value self.to_inlined[(id(obj), funcname, name)] = value
def get_constants(self, r_obj, r_funcname):
return {
local: v for (objid, funcname, local), v
in self.to_inlined.items()
if objid == id(r_obj)
and funcname == r_funcname
and not isinstance(v, (_UserVariable, ast.AST))}
def get_constants(self, r_obj, r_funcname):
return {local: v for (objid, funcname, local), v
in self.to_inlined.items()
if objid == id(r_obj)
and funcname == r_funcname
and not isinstance(v, (_UserVariable, ast.AST))}
_embeddable_calls = { _embeddable_calls = {
core_language.delay, core_language.at, core_language.now, core_language.delay, core_language.at, core_language.now,
core_language.syscall, core_language.syscall,
range, int, round, core_language.int64, core_language.round64, range, int, round, core_language.int64, core_language.round64,
Fraction, units.Quantity Fraction, units.Quantity
} }
class _ReferenceReplacer(ast.NodeTransformer): class _ReferenceReplacer(ast.NodeTransformer):
def __init__(self, core, rm, obj, funcname): def __init__(self, core, rm, obj, funcname):
self.core = core self.core = core
self.rm = rm self.rm = rm
self.obj = obj self.obj = obj
self.funcname = funcname self.funcname = funcname
def visit_ref(self, node): def visit_ref(self, node):
store = isinstance(node.ctx, ast.Store) store = isinstance(node.ctx, ast.Store)
ival = self.rm.get(self.obj, self.funcname, node) ival = self.rm.get(self.obj, self.funcname, node)
if isinstance(ival, _UserVariable): if isinstance(ival, _UserVariable):
newnode = ast.Name(ival.name, node.ctx) newnode = ast.Name(ival.name, node.ctx)
elif isinstance(ival, ast.AST): elif isinstance(ival, ast.AST):
assert(not store) assert(not store)
newnode = ival newnode = ival
else: else:
if store: if store:
raise NotImplementedError("Cannot turn object into user variable") raise NotImplementedError(
else: "Cannot turn object into user variable")
newnode = value_to_ast(ival) else:
if newnode is None: newnode = value_to_ast(ival)
raise NotImplementedError("Cannot represent inlined value") if newnode is None:
return ast.copy_location(newnode, node) raise NotImplementedError(
"Cannot represent inlined value")
return ast.copy_location(newnode, node)
visit_Name = visit_ref visit_Name = visit_ref
visit_Attribute = visit_ref visit_Attribute = visit_ref
visit_Subscript = visit_ref visit_Subscript = visit_ref
def visit_Call(self, node): def visit_Call(self, node):
func = self.rm.get(self.obj, self.funcname, node.func) func = self.rm.get(self.obj, self.funcname, node.func)
new_args = [self.visit(arg) for arg in node.args] new_args = [self.visit(arg) for arg in node.args]
if func in _embeddable_calls: if func in _embeddable_calls:
new_func = ast.Name(func.__name__, ast.Load()) new_func = ast.Name(func.__name__, ast.Load())
return ast.copy_location( return ast.copy_location(
ast.Call(func=new_func, args=new_args, ast.Call(func=new_func, args=new_args,
keywords=[], starargs=None, kwargs=None), keywords=[], starargs=None, kwargs=None),
node) node)
elif hasattr(func, "k_function_info") and getattr(func.__self__, func.k_function_info.core_name) is self.core: elif (hasattr(func, "k_function_info")
args = [func.__self__] + new_args and getattr(func.__self__, func.k_function_info.core_name)
inlined, _ = inline(self.core, func.k_function_info.k_function, args, dict(), self.rm) is self.core):
return inlined.body args = [func.__self__] + new_args
else: inlined, _ = inline(self.core, func.k_function_info.k_function,
args = [ast.Str("rpc"), value_to_ast(self.rm.rpc_map[func])] args, dict(), self.rm)
args += new_args return inlined.body
return ast.copy_location( else:
ast.Call(func=ast.Name("syscall", ast.Load()), args = [ast.Str("rpc"), value_to_ast(self.rm.rpc_map[func])]
args=args, keywords=[], starargs=None, kwargs=None), args += new_args
node) return ast.copy_location(
ast.Call(func=ast.Name("syscall", ast.Load()),
args=args, keywords=[], starargs=None, kwargs=None),
node)
def visit_Expr(self, node): def visit_Expr(self, node):
if isinstance(node.value, ast.Call): if isinstance(node.value, ast.Call):
r = self.visit_Call(node.value) r = self.visit_Call(node.value)
if isinstance(r, list): if isinstance(r, list):
return r return r
else: else:
node.value = r node.value = r
return node return node
else: else:
self.generic_visit(node) self.generic_visit(node)
return node return node
def visit_FunctionDef(self, node):
node.args = ast.arguments(args=[], vararg=None, kwonlyargs=[],
kw_defaults=[], kwarg=None, defaults=[])
node.decorator_list = []
self.generic_visit(node)
return node
def visit_FunctionDef(self, node):
node.args = ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[])
node.decorator_list = []
self.generic_visit(node)
return node
class _ListReadOnlyParams(ast.NodeVisitor): class _ListReadOnlyParams(ast.NodeVisitor):
def visit_FunctionDef(self, node): def visit_FunctionDef(self, node):
if hasattr(self, "read_only_params"): if hasattr(self, "read_only_params"):
raise ValueError("More than one function definition") raise ValueError("More than one function definition")
self.read_only_params = {arg.arg for arg in node.args.args} self.read_only_params = {arg.arg for arg in node.args.args}
self.generic_visit(node) self.generic_visit(node)
def visit_Name(self, node):
if isinstance(node.ctx, ast.Store):
try:
self.read_only_params.remove(node.id)
except KeyError:
pass
def visit_Name(self, node):
if isinstance(node.ctx, ast.Store):
try:
self.read_only_params.remove(node.id)
except KeyError:
pass
def _list_read_only_params(funcdef): def _list_read_only_params(funcdef):
lrp = _ListReadOnlyParams() lrp = _ListReadOnlyParams()
lrp.visit(funcdef) lrp.visit(funcdef)
return lrp.read_only_params return lrp.read_only_params
def _initialize_function_params(funcdef, k_args, k_kwargs, rm): def _initialize_function_params(funcdef, k_args, k_kwargs, rm):
obj = k_args[0] obj = k_args[0]
funcname = funcdef.name funcname = funcdef.name
param_init = [] param_init = []
rop = _list_read_only_params(funcdef) rop = _list_read_only_params(funcdef)
for arg_ast, arg_value in zip(funcdef.args.args, k_args): for arg_ast, arg_value in zip(funcdef.args.args, k_args):
arg_name = arg_ast.arg arg_name = arg_ast.arg
if arg_name in rop: if arg_name in rop:
rm.set(obj, funcname, arg_name, arg_value) rm.set(obj, funcname, arg_name, arg_value)
else: else:
target = rm.get(obj, funcname, ast.Name(arg_name, ast.Store())) target = rm.get(obj, funcname, ast.Name(arg_name, ast.Store()))
value = value_to_ast(arg_value) value = value_to_ast(arg_value)
param_init.append(ast.Assign(targets=[target], value=value)) param_init.append(ast.Assign(targets=[target], value=value))
return param_init return param_init
def inline(core, k_function, k_args, k_kwargs, rm=None): def inline(core, k_function, k_args, k_kwargs, rm=None):
init_kernel_attr = rm is None init_kernel_attr = rm is None
if rm is None: if rm is None:
rm = _ReferenceManager() rm = _ReferenceManager()
funcdef = ast.parse(textwrap.dedent(inspect.getsource(k_function))).body[0] funcdef = ast.parse(textwrap.dedent(inspect.getsource(k_function))).body[0]
param_init = _initialize_function_params(funcdef, k_args, k_kwargs, rm) param_init = _initialize_function_params(funcdef, k_args, k_kwargs, rm)
obj = k_args[0] obj = k_args[0]
funcname = funcdef.name funcname = funcdef.name
rr = _ReferenceReplacer(core, rm, obj, funcname) rr = _ReferenceReplacer(core, rm, obj, funcname)
rr.visit(funcdef) rr.visit(funcdef)
funcdef.body[0:0] = param_init funcdef.body[0:0] = param_init
if init_kernel_attr: if init_kernel_attr:
funcdef.body[0:0] = rm.kernel_attr_init funcdef.body[0:0] = rm.kernel_attr_init
r_rpc_map = dict((rpc_num, rpc_fun) for rpc_fun, rpc_num in rm.rpc_map.items()) r_rpc_map = dict((rpc_num, rpc_fun)
return funcdef, r_rpc_map for rpc_fun, rpc_num in rm.rpc_map.items())
return funcdef, r_rpc_map

View File

@ -1,105 +1,113 @@
import ast, types import ast
import types
from artiq.compiler.tools import * from artiq.compiler.tools import *
# -1 statement duration could not be pre-determined # -1 statement duration could not be pre-determined
# 0 statement has no effect on timeline # 0 statement has no effect on timeline
# >0 statement is a static delay that advances the timeline # >0 statement is a static delay that advances the timeline
# by the given amount (in microcycles) # by the given amount (in microcycles)
def _get_duration(stmt): def _get_duration(stmt):
if isinstance(stmt, (ast.Expr, ast.Assign)): if isinstance(stmt, (ast.Expr, ast.Assign)):
return _get_duration(stmt.value) return _get_duration(stmt.value)
elif isinstance(stmt, ast.If): elif isinstance(stmt, ast.If):
if all(_get_duration(s) == 0 for s in stmt.body) and all(_get_duration(s) == 0 for s in stmt.orelse): if (all(_get_duration(s) == 0 for s in stmt.body)
return 0 and all(_get_duration(s) == 0 for s in stmt.orelse)):
else: return 0
return -1 else:
elif isinstance(stmt, ast.Call) and isinstance(stmt.func, ast.Name): return -1
name = stmt.func.id elif isinstance(stmt, ast.Call) and isinstance(stmt.func, ast.Name):
if name == "delay": name = stmt.func.id
try: if name == "delay":
da = eval_constant(stmt.args[0]) try:
except NotConstant: da = eval_constant(stmt.args[0])
da = -1 except NotConstant:
return da da = -1
else: return da
return 0 else:
else: return 0
return 0 else:
return 0
def _interleave_timelines(timelines): def _interleave_timelines(timelines):
r = [] r = []
current_stmts = [] current_stmts = []
for stmts in timelines: for stmts in timelines:
it = iter(stmts) it = iter(stmts)
try: try:
stmt = next(it) stmt = next(it)
except StopIteration: except StopIteration:
pass pass
else: else:
current_stmts.append(types.SimpleNamespace(delay=_get_duration(stmt), stmt=stmt, it=it)) current_stmts.append(types.SimpleNamespace(
delay=_get_duration(stmt), stmt=stmt, it=it))
while current_stmts: while current_stmts:
dt = min(stmt.delay for stmt in current_stmts) dt = min(stmt.delay for stmt in current_stmts)
if dt < 0: if dt < 0:
# contains statement(s) with indeterminate duration # contains statement(s) with indeterminate duration
return None return None
if dt > 0: if dt > 0:
# advance timeline by dt # advance timeline by dt
for stmt in current_stmts: for stmt in current_stmts:
stmt.delay -= dt stmt.delay -= dt
if stmt.delay == 0: if stmt.delay == 0:
ref_stmt = stmt.stmt ref_stmt = stmt.stmt
delay_stmt = ast.copy_location( delay_stmt = ast.copy_location(
ast.Expr(ast.Call(func=ast.Name("delay", ast.Load()), ast.Expr(ast.Call(
args=[value_to_ast(dt)], func=ast.Name("delay", ast.Load()),
keywords=[], starargs=[], kwargs=[])), args=[value_to_ast(dt)],
ref_stmt) keywords=[], starargs=[], kwargs=[])),
r.append(delay_stmt) ref_stmt)
else: r.append(delay_stmt)
for stmt in current_stmts: else:
if stmt.delay == 0: for stmt in current_stmts:
r.append(stmt.stmt) if stmt.delay == 0:
# discard executed statements r.append(stmt.stmt)
exhausted_list = [] # discard executed statements
for stmt_i, stmt in enumerate(current_stmts): exhausted_list = []
if stmt.delay == 0: for stmt_i, stmt in enumerate(current_stmts):
try: if stmt.delay == 0:
stmt.stmt = next(stmt.it) try:
except StopIteration: stmt.stmt = next(stmt.it)
exhausted_list.append(stmt_i) except StopIteration:
else: exhausted_list.append(stmt_i)
stmt.delay = _get_duration(stmt.stmt) else:
for offset, i in enumerate(exhausted_list): stmt.delay = _get_duration(stmt.stmt)
current_stmts.pop(i-offset) for offset, i in enumerate(exhausted_list):
current_stmts.pop(i-offset)
return r
return r
def _interleave_stmts(stmts): def _interleave_stmts(stmts):
replacements = [] replacements = []
for stmt_i, stmt in enumerate(stmts): for stmt_i, stmt in enumerate(stmts):
if isinstance(stmt, (ast.For, ast.While, ast.If)): if isinstance(stmt, (ast.For, ast.While, ast.If)):
_interleave_stmts(stmt.body) _interleave_stmts(stmt.body)
_interleave_stmts(stmt.orelse) _interleave_stmts(stmt.orelse)
elif isinstance(stmt, ast.With): elif isinstance(stmt, ast.With):
btype = stmt.items[0].context_expr.id btype = stmt.items[0].context_expr.id
if btype == "sequential": if btype == "sequential":
_interleave_stmts(stmt.body) _interleave_stmts(stmt.body)
replacements.append((stmt_i, stmt.body)) replacements.append((stmt_i, stmt.body))
elif btype == "parallel": elif btype == "parallel":
timelines = [[s] for s in stmt.body] timelines = [[s] for s in stmt.body]
for timeline in timelines: for timeline in timelines:
_interleave_stmts(timeline) _interleave_stmts(timeline)
merged = _interleave_timelines(timelines) merged = _interleave_timelines(timelines)
if merged is not None: if merged is not None:
replacements.append((stmt_i, merged)) replacements.append((stmt_i, merged))
else: else:
raise ValueError("Unknown block type: " + btype) raise ValueError("Unknown block type: " + btype)
offset = 0 offset = 0
for location, new_stmts in replacements: for location, new_stmts in replacements:
stmts[offset+location:offset+location+1] = new_stmts stmts[offset+location:offset+location+1] = new_stmts
offset += len(new_stmts) - 1 offset += len(new_stmts) - 1
def interleave(funcdef): def interleave(funcdef):
_interleave_stmts(funcdef.body) _interleave_stmts(funcdef.body)

View File

@ -3,32 +3,34 @@ from llvm import passes as lp
from artiq.compiler import ir_infer_types, ir_ast_body, ir_values from artiq.compiler import ir_infer_types, ir_ast_body, ir_values
def compile_function(module, env, funcdef):
function_type = lc.Type.function(lc.Type.void(), [])
function = module.add_function(function_type, funcdef.name)
bb = function.append_basic_block("entry")
builder = lc.Builder.new(bb)
ns = ir_infer_types.infer_types(env, funcdef) def compile_function(module, env, funcdef):
for k, v in ns.items(): function_type = lc.Type.function(lc.Type.void(), [])
v.alloca(builder, k) function = module.add_function(function_type, funcdef.name)
visitor = ir_ast_body.Visitor(env, ns, builder) bb = function.append_basic_block("entry")
visitor.visit_statements(funcdef.body) builder = lc.Builder.new(bb)
builder.ret_void()
ns = ir_infer_types.infer_types(env, funcdef)
for k, v in ns.items():
v.alloca(builder, k)
visitor = ir_ast_body.Visitor(env, ns, builder)
visitor.visit_statements(funcdef.body)
builder.ret_void()
def get_runtime_binary(env, funcdef): def get_runtime_binary(env, funcdef):
module = lc.Module.new("main") module = lc.Module.new("main")
env.init_module(module) env.init_module(module)
ir_values.init_module(module) ir_values.init_module(module)
compile_function(module, env, funcdef) compile_function(module, env, funcdef)
pass_manager = lp.PassManager.new() pass_manager = lp.PassManager.new()
pass_manager.add(lp.PASS_MEM2REG) pass_manager.add(lp.PASS_MEM2REG)
pass_manager.add(lp.PASS_INSTCOMBINE) pass_manager.add(lp.PASS_INSTCOMBINE)
pass_manager.add(lp.PASS_REASSOCIATE) pass_manager.add(lp.PASS_REASSOCIATE)
pass_manager.add(lp.PASS_GVN) pass_manager.add(lp.PASS_GVN)
pass_manager.add(lp.PASS_SIMPLIFYCFG) pass_manager.add(lp.PASS_SIMPLIFYCFG)
pass_manager.run(module) pass_manager.run(module)
return env.emit_object() return env.emit_object()

View File

@ -2,187 +2,203 @@ import ast
from artiq.compiler import ir_values from artiq.compiler import ir_values
class Visitor: class Visitor:
def __init__(self, env, ns, builder=None): def __init__(self, env, ns, builder=None):
self.env = env self.env = env
self.ns = ns self.ns = ns
self.builder = builder self.builder = builder
# builder can be None for visit_expression # builder can be None for visit_expression
def visit_expression(self, node): def visit_expression(self, node):
method = "_visit_expr_" + node.__class__.__name__ method = "_visit_expr_" + node.__class__.__name__
try: try:
visitor = getattr(self, method) visitor = getattr(self, method)
except AttributeError: except AttributeError:
raise NotImplementedError("Unsupported node '{}' in expression".format(node.__class__.__name__)) raise NotImplementedError("Unsupported node '{}' in expression"
return visitor(node) .format(node.__class__.__name__))
return visitor(node)
def _visit_expr_Name(self, node): def _visit_expr_Name(self, node):
try: try:
r = self.ns[node.id] r = self.ns[node.id]
except KeyError: except KeyError:
raise NameError("Name '{}' is not defined".format(node.id)) raise NameError("Name '{}' is not defined".format(node.id))
return r return r
def _visit_expr_NameConstant(self, node): def _visit_expr_NameConstant(self, node):
v = node.value v = node.value
if v is None: if v is None:
r = ir_values.VNone() r = ir_values.VNone()
elif isinstance(v, bool): elif isinstance(v, bool):
r = ir_values.VBool() r = ir_values.VBool()
else: else:
raise NotImplementedError raise NotImplementedError
if self.builder is not None: if self.builder is not None:
r.set_const_value(self.builder, v) r.set_const_value(self.builder, v)
return r return r
def _visit_expr_Num(self, node): def _visit_expr_Num(self, node):
n = node.n n = node.n
if isinstance(n, int): if isinstance(n, int):
if abs(n) < 2**31: if abs(n) < 2**31:
r = ir_values.VInt() r = ir_values.VInt()
else: else:
r = ir_values.VInt(64) r = ir_values.VInt(64)
else: else:
raise NotImplementedError raise NotImplementedError
if self.builder is not None: if self.builder is not None:
r.set_const_value(self.builder, n) r.set_const_value(self.builder, n)
return r return r
def _visit_expr_UnaryOp(self, node): def _visit_expr_UnaryOp(self, node):
ast_unops = { ast_unops = {
ast.Invert: ir_values.operators.inv, ast.Invert: ir_values.operators.inv,
ast.Not: ir_values.operators.not_, ast.Not: ir_values.operators.not_,
ast.UAdd: ir_values.operators.pos, ast.UAdd: ir_values.operators.pos,
ast.USub: ir_values.operators.neg ast.USub: ir_values.operators.neg
} }
return ast_unops[type(node.op)](self.visit_expression(node.operand), self.builder) return ast_unops[type(node.op)](self.visit_expression(node.operand),
self.builder)
def _visit_expr_BinOp(self, node): def _visit_expr_BinOp(self, node):
ast_binops = { ast_binops = {
ast.Add: ir_values.operators.add, ast.Add: ir_values.operators.add,
ast.Sub: ir_values.operators.sub, ast.Sub: ir_values.operators.sub,
ast.Mult: ir_values.operators.mul, ast.Mult: ir_values.operators.mul,
ast.Div: ir_values.operators.truediv, ast.Div: ir_values.operators.truediv,
ast.FloorDiv: ir_values.operators.floordiv, ast.FloorDiv: ir_values.operators.floordiv,
ast.Mod: ir_values.operators.mod, ast.Mod: ir_values.operators.mod,
ast.Pow: ir_values.operators.pow, ast.Pow: ir_values.operators.pow,
ast.LShift: ir_values.operators.lshift, ast.LShift: ir_values.operators.lshift,
ast.RShift: ir_values.operators.rshift, ast.RShift: ir_values.operators.rshift,
ast.BitOr: ir_values.operators.or_, ast.BitOr: ir_values.operators.or_,
ast.BitXor: ir_values.operators.xor, ast.BitXor: ir_values.operators.xor,
ast.BitAnd: ir_values.operators.and_ ast.BitAnd: ir_values.operators.and_
} }
return ast_binops[type(node.op)](self.visit_expression(node.left), self.visit_expression(node.right), self.builder) return ast_binops[type(node.op)](self.visit_expression(node.left),
self.visit_expression(node.right),
self.builder)
def _visit_expr_Compare(self, node): def _visit_expr_Compare(self, node):
ast_cmps = { ast_cmps = {
ast.Eq: ir_values.operators.eq, ast.Eq: ir_values.operators.eq,
ast.NotEq: ir_values.operators.ne, ast.NotEq: ir_values.operators.ne,
ast.Lt: ir_values.operators.lt, ast.Lt: ir_values.operators.lt,
ast.LtE: ir_values.operators.le, ast.LtE: ir_values.operators.le,
ast.Gt: ir_values.operators.gt, ast.Gt: ir_values.operators.gt,
ast.GtE: ir_values.operators.ge ast.GtE: ir_values.operators.ge
} }
comparisons = [] comparisons = []
old_comparator = self.visit_expression(node.left) old_comparator = self.visit_expression(node.left)
for op, comparator_a in zip(node.ops, node.comparators): for op, comparator_a in zip(node.ops, node.comparators):
comparator = self.visit_expression(comparator_a) comparator = self.visit_expression(comparator_a)
comparison = ast_cmps[type(op)](old_comparator, comparator, self.builder) comparison = ast_cmps[type(op)](old_comparator, comparator,
comparisons.append(comparison) self.builder)
old_comparator = comparator comparisons.append(comparison)
r = comparisons[0] old_comparator = comparator
for comparison in comparisons[1:]: r = comparisons[0]
r = ir_values.operators.and_(r, comparison) for comparison in comparisons[1:]:
return r r = ir_values.operators.and_(r, comparison)
return r
def _visit_expr_Call(self, node): def _visit_expr_Call(self, node):
ast_unfuns = { ast_unfuns = {
"bool": ir_values.operators.bool, "bool": ir_values.operators.bool,
"int": ir_values.operators.int, "int": ir_values.operators.int,
"int64": ir_values.operators.int64, "int64": ir_values.operators.int64,
"round": ir_values.operators.round, "round": ir_values.operators.round,
"round64": ir_values.operators.round64, "round64": ir_values.operators.round64,
} }
fn = node.func.id fn = node.func.id
if fn in ast_unfuns: if fn in ast_unfuns:
return ast_unfuns[fn](self.visit_expression(node.args[0]), self.builder) return ast_unfuns[fn](self.visit_expression(node.args[0]),
elif fn == "Fraction": self.builder)
r = ir_values.VFraction() elif fn == "Fraction":
if self.builder is not None: r = ir_values.VFraction()
numerator = self.visit_expression(node.args[0]) if self.builder is not None:
denominator = self.visit_expression(node.args[1]) numerator = self.visit_expression(node.args[0])
r.set_value_nd(self.builder, numerator, denominator) denominator = self.visit_expression(node.args[1])
return r r.set_value_nd(self.builder, numerator, denominator)
elif fn == "syscall": return r
return self.env.syscall(node.args[0].s, elif fn == "syscall":
[self.visit_expression(expr) for expr in node.args[1:]], return self.env.syscall(
self.builder) node.args[0].s,
else: [self.visit_expression(expr) for expr in node.args[1:]],
raise NameError("Function '{}' is not defined".format(fn)) self.builder)
else:
raise NameError("Function '{}' is not defined".format(fn))
def visit_statements(self, stmts): def visit_statements(self, stmts):
for node in stmts: for node in stmts:
method = "_visit_stmt_" + node.__class__.__name__ method = "_visit_stmt_" + node.__class__.__name__
try: try:
visitor = getattr(self, method) visitor = getattr(self, method)
except AttributeError: except AttributeError:
raise NotImplementedError("Unsupported node '{}' in statement".format(node.__class__.__name__)) raise NotImplementedError("Unsupported node '{}' in statement"
visitor(node) .format(node.__class__.__name__))
visitor(node)
def _visit_stmt_Assign(self, node): def _visit_stmt_Assign(self, node):
val = self.visit_expression(node.value) val = self.visit_expression(node.value)
for target in node.targets: for target in node.targets:
if isinstance(target, ast.Name): if isinstance(target, ast.Name):
self.ns[target.id].set_value(self.builder, val) self.ns[target.id].set_value(self.builder, val)
else: else:
raise NotImplementedError raise NotImplementedError
def _visit_stmt_AugAssign(self, node): def _visit_stmt_AugAssign(self, node):
val = self.visit_expression(ast.BinOp(op=node.op, left=node.target, right=node.value)) val = self.visit_expression(ast.BinOp(op=node.op, left=node.target,
if isinstance(node.target, ast.Name): right=node.value))
self.ns[node.target.id].set_value(self.builder, val) if isinstance(node.target, ast.Name):
else: self.ns[node.target.id].set_value(self.builder, val)
raise NotImplementedError else:
raise NotImplementedError
def _visit_stmt_Expr(self, node): def _visit_stmt_Expr(self, node):
self.visit_expression(node.value) self.visit_expression(node.value)
def _visit_stmt_If(self, node): def _visit_stmt_If(self, node):
function = self.builder.basic_block.function function = self.builder.basic_block.function
then_block = function.append_basic_block("i_then") then_block = function.append_basic_block("i_then")
else_block = function.append_basic_block("i_else") else_block = function.append_basic_block("i_else")
merge_block = function.append_basic_block("i_merge") merge_block = function.append_basic_block("i_merge")
condition = ir_values.operators.bool(self.visit_expression(node.test), self.builder) condition = ir_values.operators.bool(self.visit_expression(node.test),
self.builder.cbranch(condition.get_ssa_value(self.builder), then_block, else_block) self.builder)
self.builder.cbranch(condition.get_ssa_value(self.builder),
then_block, else_block)
self.builder.position_at_end(then_block) self.builder.position_at_end(then_block)
self.visit_statements(node.body) self.visit_statements(node.body)
self.builder.branch(merge_block) self.builder.branch(merge_block)
self.builder.position_at_end(else_block) self.builder.position_at_end(else_block)
self.visit_statements(node.orelse) self.visit_statements(node.orelse)
self.builder.branch(merge_block) self.builder.branch(merge_block)
self.builder.position_at_end(merge_block) self.builder.position_at_end(merge_block)
def _visit_stmt_While(self, node): def _visit_stmt_While(self, node):
function = self.builder.basic_block.function function = self.builder.basic_block.function
body_block = function.append_basic_block("w_body") body_block = function.append_basic_block("w_body")
else_block = function.append_basic_block("w_else") else_block = function.append_basic_block("w_else")
merge_block = function.append_basic_block("w_merge") merge_block = function.append_basic_block("w_merge")
condition = ir_values.operators.bool(self.visit_expression(node.test), self.builder) condition = ir_values.operators.bool(
self.builder.cbranch(condition.get_ssa_value(self.builder), body_block, else_block) self.visit_expression(node.test), self.builder)
self.builder.cbranch(
condition.get_ssa_value(self.builder), body_block, else_block)
self.builder.position_at_end(body_block) self.builder.position_at_end(body_block)
self.visit_statements(node.body) self.visit_statements(node.body)
condition = ir_values.operators.bool(self.visit_expression(node.test), self.builder) condition = ir_values.operators.bool(
self.builder.cbranch(condition.get_ssa_value(self.builder), body_block, merge_block) self.visit_expression(node.test), self.builder)
self.builder.cbranch(
condition.get_ssa_value(self.builder), body_block, merge_block)
self.builder.position_at_end(else_block) self.builder.position_at_end(else_block)
self.visit_statements(node.orelse) self.visit_statements(node.orelse)
self.builder.branch(merge_block) self.builder.branch(merge_block)
self.builder.position_at_end(merge_block) self.builder.position_at_end(merge_block)

View File

@ -4,46 +4,49 @@ from copy import deepcopy
from artiq.compiler.ir_ast_body import Visitor from artiq.compiler.ir_ast_body import Visitor
class _TypeScanner(ast.NodeVisitor): class _TypeScanner(ast.NodeVisitor):
def __init__(self, env, ns): def __init__(self, env, ns):
self.exprv = Visitor(env, ns) self.exprv = Visitor(env, ns)
def visit_Assign(self, node): def visit_Assign(self, node):
val = self.exprv.visit_expression(node.value) val = self.exprv.visit_expression(node.value)
ns = self.exprv.ns ns = self.exprv.ns
for target in node.targets: for target in node.targets:
if isinstance(target, ast.Name): if isinstance(target, ast.Name):
if target.id in ns: if target.id in ns:
ns[target.id].merge(val) ns[target.id].merge(val)
else: else:
ns[target.id] = val ns[target.id] = val
else: else:
raise NotImplementedError raise NotImplementedError
def visit_AugAssign(self, node):
val = self.exprv.visit_expression(ast.BinOp(
op=node.op, left=node.target, right=node.value))
ns = self.exprv.ns
target = node.target
if isinstance(target, ast.Name):
if target.id in ns:
ns[target.id].merge(val)
else:
ns[target.id] = val
else:
raise NotImplementedError
def visit_AugAssign(self, node):
val = self.exprv.visit_expression(ast.BinOp(op=node.op, left=node.target, right=node.value))
ns = self.exprv.ns
target = node.target
if isinstance(target, ast.Name):
if target.id in ns:
ns[target.id].merge(val)
else:
ns[target.id] = val
else:
raise NotImplementedError
def infer_types(env, node): def infer_types(env, node):
ns = dict() ns = dict()
while True: while True:
prev_ns = deepcopy(ns) prev_ns = deepcopy(ns)
ts = _TypeScanner(env, ns) ts = _TypeScanner(env, ns)
ts.visit(node) ts.visit(node)
if prev_ns and all(v.same_type(prev_ns[k]) for k, v in ns.items()): if prev_ns and all(v.same_type(prev_ns[k]) for k, v in ns.items()):
# no more promotions - completed # no more promotions - completed
return ns return ns
if __name__ == "__main__": if __name__ == "__main__":
testcode = """ testcode = """
a = 2 # promoted later to int64 a = 2 # promoted later to int64
b = a + 1 # initially int32, becomes int64 after a is promoted b = a + 1 # initially int32, becomes int64 after a is promoted
c = b//2 # initially int32, becomes int64 after b is promoted c = b//2 # initially int32, becomes int64 after b is promoted
@ -53,6 +56,6 @@ a += x # promotes a to int64
foo = True foo = True
bar = None bar = None
""" """
ns = infer_types(None, ast.parse(testcode)) ns = infer_types(None, ast.parse(testcode))
for k, v in sorted(ns.items(), key=itemgetter(0)): for k, v in sorted(ns.items(), key=itemgetter(0)):
print("{:10}--> {}".format(k, str(v))) print("{:10}--> {}".format(k, str(v)))

View File

@ -2,404 +2,450 @@ from types import SimpleNamespace
from llvm import core as lc from llvm import core as lc
class _Value: class _Value:
def __init__(self): def __init__(self):
self._llvm_value = None self._llvm_value = None
def get_ssa_value(self, builder): def get_ssa_value(self, builder):
if isinstance(self._llvm_value, lc.AllocaInstruction): if isinstance(self._llvm_value, lc.AllocaInstruction):
return builder.load(self._llvm_value) return builder.load(self._llvm_value)
else: else:
return self._llvm_value return self._llvm_value
def set_ssa_value(self, builder, value): def set_ssa_value(self, builder, value):
if self._llvm_value is None: if self._llvm_value is None:
self._llvm_value = value self._llvm_value = value
elif isinstance(self._llvm_value, lc.AllocaInstruction): elif isinstance(self._llvm_value, lc.AllocaInstruction):
builder.store(value, self._llvm_value) builder.store(value, self._llvm_value)
else: else:
raise RuntimeError("Attempted to set LLVM SSA value multiple times") raise RuntimeError(
"Attempted to set LLVM SSA value multiple times")
def alloca(self, builder, name): def alloca(self, builder, name):
if self._llvm_value is not None: if self._llvm_value is not None:
raise RuntimeError("Attempted to alloca existing LLVM value") raise RuntimeError("Attempted to alloca existing LLVM value")
self._llvm_value = builder.alloca(self.get_llvm_type(), name=name) self._llvm_value = builder.alloca(self.get_llvm_type(), name=name)
def o_int(self, builder): def o_int(self, builder):
return self.o_intx(32, builder) return self.o_intx(32, builder)
def o_int64(self, builder): def o_int64(self, builder):
return self.o_intx(64, builder) return self.o_intx(64, builder)
def o_round(self, builder): def o_round(self, builder):
return self.o_roundx(32, builder) return self.o_roundx(32, builder)
def o_round64(self, builder):
return self.o_roundx(64, builder)
def o_round64(self, builder):
return self.o_roundx(64, builder)
# None type # None type
class VNone(_Value): class VNone(_Value):
def __repr__(self): def __repr__(self):
return "<VNone>" return "<VNone>"
def get_llvm_type(self): def get_llvm_type(self):
return lc.Type.void() return lc.Type.void()
def same_type(self, other): def same_type(self, other):
return isinstance(other, VNone) return isinstance(other, VNone)
def merge(self, other): def merge(self, other):
if not isinstance(other, VNone): if not isinstance(other, VNone):
raise TypeError raise TypeError
def alloca(self, builder, name): def alloca(self, builder, name):
pass pass
def o_bool(self, builder):
r = VBool()
if builder is not None:
r.set_const_value(builder, False)
return r
def o_bool(self, builder):
r = VBool()
if builder is not None:
r.set_const_value(builder, False)
return r
# Integer type # Integer type
class VInt(_Value): class VInt(_Value):
def __init__(self, nbits=32): def __init__(self, nbits=32):
_Value.__init__(self) _Value.__init__(self)
self.nbits = nbits self.nbits = nbits
def get_llvm_type(self): def get_llvm_type(self):
return lc.Type.int(self.nbits) return lc.Type.int(self.nbits)
def __repr__(self): def __repr__(self):
return "<VInt:{}>".format(self.nbits) return "<VInt:{}>".format(self.nbits)
def same_type(self, other): def same_type(self, other):
return isinstance(other, VInt) and other.nbits == self.nbits return isinstance(other, VInt) and other.nbits == self.nbits
def merge(self, other): def merge(self, other):
if isinstance(other, VInt) and not isinstance(other, VBool): if isinstance(other, VInt) and not isinstance(other, VBool):
if other.nbits > self.nbits: if other.nbits > self.nbits:
self.nbits = other.nbits self.nbits = other.nbits
else: else:
raise TypeError raise TypeError
def set_value(self, builder, n): def set_value(self, builder, n):
self.set_ssa_value(builder, n.o_intx(self.nbits, builder).get_ssa_value(builder)) self.set_ssa_value(
builder, n.o_intx(self.nbits, builder).get_ssa_value(builder))
def set_const_value(self, builder, n): def set_const_value(self, builder, n):
self.set_ssa_value(builder, lc.Constant.int(self.get_llvm_type(), n)) self.set_ssa_value(builder, lc.Constant.int(self.get_llvm_type(), n))
def o_bool(self, builder): def o_bool(self, builder):
r = VBool() r = VBool()
if builder is not None: if builder is not None:
r.set_ssa_value(builder, builder.icmp(lc.ICMP_NE, r.set_ssa_value(
self.get_ssa_value(builder), lc.Constant.int(self.get_llvm_type(), 0))) builder, builder.icmp(
return r lc.ICMP_NE,
self.get_ssa_value(builder),
lc.Constant.int(self.get_llvm_type(), 0)))
return r
def o_intx(self, target_bits, builder):
r = VInt(target_bits)
if builder is not None:
if self.nbits == target_bits:
r.set_ssa_value(
builder, self.get_ssa_value(builder))
if self.nbits > target_bits:
r.set_ssa_value(
builder, builder.trunc(self.get_ssa_value(builder),
r.get_llvm_type()))
if self.nbits < target_bits:
r.set_ssa_value(
builder, builder.sext(self.get_ssa_value(builder),
r.get_llvm_type()))
return r
o_roundx = o_intx
def o_intx(self, target_bits, builder):
r = VInt(target_bits)
if builder is not None:
if self.nbits == target_bits:
r.set_ssa_value(builder, self.get_ssa_value(builder))
if self.nbits > target_bits:
r.set_ssa_value(builder, builder.trunc(self.get_ssa_value(builder), r.get_llvm_type()))
if self.nbits < target_bits:
r.set_ssa_value(builder, builder.sext(self.get_ssa_value(builder), r.get_llvm_type()))
return r
o_roundx = o_intx
def _make_vint_binop_method(builder_name): def _make_vint_binop_method(builder_name):
def binop_method(self, other, builder): def binop_method(self, other, builder):
if isinstance(other, VInt): if isinstance(other, VInt):
target_bits = max(self.nbits, other.nbits) target_bits = max(self.nbits, other.nbits)
r = VInt(target_bits) r = VInt(target_bits)
if builder is not None: if builder is not None:
left = self.o_intx(target_bits, builder) left = self.o_intx(target_bits, builder)
right = other.o_intx(target_bits, builder) right = other.o_intx(target_bits, builder)
bf = getattr(builder, builder_name) bf = getattr(builder, builder_name)
r.set_ssa_value(builder, r.set_ssa_value(
bf(left.get_ssa_value(builder), right.get_ssa_value(builder))) builder, bf(left.get_ssa_value(builder),
return r right.get_ssa_value(builder)))
else: return r
return NotImplemented else:
return binop_method return NotImplemented
return binop_method
for _method_name, _builder_name in (("o_add", "add"),
("o_sub", "sub"),
("o_mul", "mul"),
("o_floordiv", "sdiv"),
("o_mod", "srem"),
("o_and", "and_"),
("o_xor", "xor"),
("o_or", "or_")):
setattr(VInt, _method_name, _make_vint_binop_method(_builder_name))
for _method_name, _builder_name in (
("o_add", "add"),
("o_sub", "sub"),
("o_mul", "mul"),
("o_floordiv", "sdiv"),
("o_mod", "srem"),
("o_and", "and_"),
("o_xor", "xor"),
("o_or", "or_")):
setattr(VInt, _method_name, _make_vint_binop_method(_builder_name))
def _make_vint_cmp_method(icmp_val): def _make_vint_cmp_method(icmp_val):
def cmp_method(self, other, builder): def cmp_method(self, other, builder):
if isinstance(other, VInt): if isinstance(other, VInt):
r = VBool() r = VBool()
if builder is not None: if builder is not None:
target_bits = max(self.nbits, other.nbits) target_bits = max(self.nbits, other.nbits)
left = self.o_intx(target_bits, builder) left = self.o_intx(target_bits, builder)
right = other.o_intx(target_bits, builder) right = other.o_intx(target_bits, builder)
r.set_ssa_value(builder, r.set_ssa_value(
builder.icmp(icmp_val, left.get_ssa_value(builder), right.get_ssa_value(builder))) builder,
return r builder.icmp(
else: icmp_val, left.get_ssa_value(builder),
return NotImplemented right.get_ssa_value(builder)))
return cmp_method return r
else:
return NotImplemented
return cmp_method
for _method_name, _icmp_val in (("o_eq", lc.ICMP_EQ),
("o_ne", lc.ICMP_NE),
("o_lt", lc.ICMP_SLT),
("o_le", lc.ICMP_SLE),
("o_gt", lc.ICMP_SGT),
("o_ge", lc.ICMP_SGE)):
setattr(VInt, _method_name, _make_vint_cmp_method(_icmp_val))
for _method_name, _icmp_val in (
("o_eq", lc.ICMP_EQ),
("o_ne", lc.ICMP_NE),
("o_lt", lc.ICMP_SLT),
("o_le", lc.ICMP_SLE),
("o_gt", lc.ICMP_SGT),
("o_ge", lc.ICMP_SGE)):
setattr(VInt, _method_name, _make_vint_cmp_method(_icmp_val))
# Boolean type # Boolean type
class VBool(VInt): class VBool(VInt):
def __init__(self): def __init__(self):
VInt.__init__(self, 1) VInt.__init__(self, 1)
def __repr__(self): def __repr__(self):
return "<VBool>" return "<VBool>"
def same_type(self, other): def same_type(self, other):
return isinstance(other, VBool) return isinstance(other, VBool)
def merge(self, other): def merge(self, other):
if not isinstance(other, VBool): if not isinstance(other, VBool):
raise TypeError raise TypeError
def set_const_value(self, builder, b): def set_const_value(self, builder, b):
VInt.set_const_value(self, builder, int(b)) VInt.set_const_value(self, builder, int(b))
def o_bool(self, builder):
r = VBool()
if builder is not None:
r.set_ssa_value(builder, self.get_ssa_value(builder))
return r
def o_bool(self, builder):
r = VBool()
if builder is not None:
r.set_ssa_value(builder, self.get_ssa_value(builder))
return r
# Fraction type # Fraction type
def _gcd64(builder, a, b): def _gcd64(builder, a, b):
gcd_f = builder.module.get_function_named("__gcd64") gcd_f = builder.module.get_function_named("__gcd64")
return builder.call(gcd_f, [a, b]) return builder.call(gcd_f, [a, b])
def _frac_normalize(builder, numerator, denominator): def _frac_normalize(builder, numerator, denominator):
gcd = _gcd64(numerator, denominator) gcd = _gcd64(numerator, denominator)
numerator = builder.sdiv(numerator, gcd) numerator = builder.sdiv(numerator, gcd)
denominator = builder.sdiv(denominator, gcd) denominator = builder.sdiv(denominator, gcd)
return numerator, denominator return numerator, denominator
def _frac_make_ssa(builder, numerator, denominator): def _frac_make_ssa(builder, numerator, denominator):
value = lc.Constant.undef(lc.Type.vector(lc.Type.int(64), 2)) value = lc.Constant.undef(lc.Type.vector(lc.Type.int(64), 2))
value = builder.insert_element(value, numerator, lc.Constant.int(lc.Type.int(), 0)) value = builder.insert_element(
value = builder.insert_element(value, denominator, lc.Constant.int(lc.Type.int(), 1)) value, numerator, lc.Constant.int(lc.Type.int(), 0))
return value value = builder.insert_element(
value, denominator, lc.Constant.int(lc.Type.int(), 1))
return value
class VFraction(_Value): class VFraction(_Value):
def get_llvm_type(self): def get_llvm_type(self):
return lc.Type.vector(lc.Type.int(64), 2) return lc.Type.vector(lc.Type.int(64), 2)
def __repr__(self): def __repr__(self):
return "<VFraction>" return "<VFraction>"
def same_type(self, other): def same_type(self, other):
return isinstance(other, VFraction) return isinstance(other, VFraction)
def merge(self, other): def merge(self, other):
if not isinstance(other, VFraction): if not isinstance(other, VFraction):
raise TypeError raise TypeError
def _nd(self, builder, invert=False): def _nd(self, builder, invert=False):
ssa_value = self.get_ssa_value(builder) ssa_value = self.get_ssa_value(builder)
numerator = builder.extract_element(ssa_value, lc.Constant.int(lc.Type.int(), 0)) numerator = builder.extract_element(
denominator = builder.extract_element(ssa_value, lc.Constant.int(lc.Type.int(), 1)) ssa_value, lc.Constant.int(lc.Type.int(), 0))
if invert: denominator = builder.extract_element(
return denominator, numerator ssa_value, lc.Constant.int(lc.Type.int(), 1))
else: if invert:
return numerator, denominator return denominator, numerator
else:
return numerator, denominator
def set_value_nd(self, builder, numerator, denominator): def set_value_nd(self, builder, numerator, denominator):
numerator = numerator.o_int64(builder).get_ssa_value(builder) numerator = numerator.o_int64(builder).get_ssa_value(builder)
denominator = denominator.o_int64(builder).get_ssa_value(builder) denominator = denominator.o_int64(builder).get_ssa_value(builder)
numerator, denominator = _frac_normalize(builder, numerator, denominator) numerator, denominator = _frac_normalize(
self.set_ssa_value(builder, _frac_make_ssa(builder, numerator, denominator)) builder, numerator, denominator)
self.set_ssa_value(
builder, _frac_make_ssa(builder, numerator, denominator))
def set_value(self, builder, n): def set_value(self, builder, n):
if not isinstance(n, VFraction): if not isinstance(n, VFraction):
raise TypeError raise TypeError
self.set_ssa_value(builder, n.get_ssa_value(builder)) self.set_ssa_value(builder, n.get_ssa_value(builder))
def o_bool(self, builder): def o_bool(self, builder):
r = VBool() r = VBool()
if builder is not None: if builder is not None:
zero = lc.Constant.int(lc.Type.int(64), 0) zero = lc.Constant.int(lc.Type.int(64), 0)
numerator = builder.extract_element(self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 0)) numerator = builder.extract_element(
r.set_ssa_value(builder, builder.icmp(lc.ICMP_NE, numerator, zero)) self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 0))
return r r.set_ssa_value(builder, builder.icmp(lc.ICMP_NE, numerator, zero))
return r
def o_intx(self, target_bits, builder): def o_intx(self, target_bits, builder):
if builder is None: if builder is None:
return VInt(target_bits) return VInt(target_bits)
else: else:
r = VInt(64) r = VInt(64)
numerator, denominator = self._nd(builder) numerator, denominator = self._nd(builder)
r.set_ssa_value(builder, builder.sdiv(numerator, denominator)) r.set_ssa_value(builder, builder.sdiv(numerator, denominator))
return r.o_intx(target_bits, builder) return r.o_intx(target_bits, builder)
def o_roundx(self, target_bits, builder): def o_roundx(self, target_bits, builder):
if builder is None: if builder is None:
return VInt(target_bits) return VInt(target_bits)
else: else:
r = VInt(64) r = VInt(64)
numerator, denominator = self._nd(builder) numerator, denominator = self._nd(builder)
h_denominator = builder.ashr(denominator, lc.Constant.int(lc.Type.int(), 1)) h_denominator = builder.ashr(denominator,
r_numerator = builder.add(numerator, h_denominator) lc.Constant.int(lc.Type.int(), 1))
r.set_ssa_value(builder, builder.sdiv(r_numerator, denominator)) r_numerator = builder.add(numerator, h_denominator)
return r.o_intx(target_bits, builder) r.set_ssa_value(builder, builder.sdiv(r_numerator, denominator))
return r.o_intx(target_bits, builder)
def _o_eq_inv(self, other, builder, ne): def _o_eq_inv(self, other, builder, ne):
if isinstance(other, VFraction): if isinstance(other, VFraction):
r = VBool() r = VBool()
if builder is not None: if builder is not None:
ee = [] ee = []
for i in range(2): for i in range(2):
es = builder.extract_element(self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), i)) es = builder.extract_element(
eo = builder.extract_element(other.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), i)) self.get_ssa_value(builder),
ee.append(builder.icmp(lc.ICMP_EQ, es, eo)) lc.Constant.int(lc.Type.int(), i))
ssa_r = builder.and_(ee[0], ee[1]) eo = builder.extract_element(
if ne: other.get_ssa_value(builder),
ssa_r = builder.xor(ssa_r, lc.Constant.int(lc.Type.int(1), 1)) lc.Constant.int(lc.Type.int(), i))
r.set_ssa_value(builder, ssa_r) ee.append(builder.icmp(lc.ICMP_EQ, es, eo))
return r ssa_r = builder.and_(ee[0], ee[1])
else: if ne:
return NotImplemented ssa_r = builder.xor(ssa_r,
lc.Constant.int(lc.Type.int(1), 1))
r.set_ssa_value(builder, ssa_r)
return r
else:
return NotImplemented
def o_eq(self, other, builder): def o_eq(self, other, builder):
return self._o_eq_inv(other, builder, False) return self._o_eq_inv(other, builder, False)
def o_ne(self, other, builder): def o_ne(self, other, builder):
return self._o_eq_inv(other, builder, True) return self._o_eq_inv(other, builder, True)
def _o_muldiv(self, other, builder, div, invert=False): def _o_muldiv(self, other, builder, div, invert=False):
r = VFraction() r = VFraction()
if isinstance(other, VInt): if isinstance(other, VInt):
if builder is None: if builder is None:
return r return r
else: else:
numerator, denominator = self._nd(builder, invert) numerator, denominator = self._nd(builder, invert)
i = other.get_ssa_value(builder) i = other.get_ssa_value(builder)
if div: if div:
gcd = _gcd64(i, numerator) gcd = _gcd64(i, numerator)
i = builder.sdiv(i, gcd) i = builder.sdiv(i, gcd)
numerator = builder.sdiv(numerator, gcd) numerator = builder.sdiv(numerator, gcd)
denominator = builder.mul(denominator, i) denominator = builder.mul(denominator, i)
else: else:
gcd = _gcd64(i, denominator) gcd = _gcd64(i, denominator)
i = builder.sdiv(i, gcd) i = builder.sdiv(i, gcd)
denominator = builder.sdiv(denominator, gcd) denominator = builder.sdiv(denominator, gcd)
numerator = builder.mul(numerator, i) numerator = builder.mul(numerator, i)
self.set_ssa_value(builder, _frac_make_ssa(builder, numerator, denominator)) self.set_ssa_value(builder, _frac_make_ssa(builder, numerator,
elif isinstance(other, VFraction): denominator))
if builder is None: elif isinstance(other, VFraction):
return r if builder is None:
else: return r
numerator, denominator = self._nd(builder, invert) else:
onumerator, odenominator = other._nd(builder) numerator, denominator = self._nd(builder, invert)
if div: onumerator, odenominator = other._nd(builder)
numerator = builder.mul(numerator, odenominator) if div:
denominator = builder.mul(denominator, onumerator) numerator = builder.mul(numerator, odenominator)
else: denominator = builder.mul(denominator, onumerator)
numerator = builder.mul(numerator, onumerator) else:
denominator = builder.mul(denominator, odenominator) numerator = builder.mul(numerator, onumerator)
numerator, denominator = _frac_normalize(builder, numerator, denominator) denominator = builder.mul(denominator, odenominator)
self.set_ssa_value(builder, _frac_make_ssa(builder, numerator, denominator)) numerator, denominator = _frac_normalize(builder, numerator,
else: denominator)
return NotImplemented self.set_ssa_value(
builder, _frac_make_ssa(builder, numerator, denominator))
else:
return NotImplemented
def o_mul(self, other, builder): def o_mul(self, other, builder):
return self._o_muldiv(other, builder, False) return self._o_muldiv(other, builder, False)
def o_truediv(self, other, builder): def o_truediv(self, other, builder):
return self._o_muldiv(other, builder, True) return self._o_muldiv(other, builder, True)
def or_mul(self, other, builder): def or_mul(self, other, builder):
return self._o_muldiv(other, builder, False) return self._o_muldiv(other, builder, False)
def or_truediv(self, other, builder): def or_truediv(self, other, builder):
return self._o_muldiv(other, builder, False, True) return self._o_muldiv(other, builder, False, True)
def o_floordiv(self, other, builder): def o_floordiv(self, other, builder):
r = self.o_truediv(other, builder) r = self.o_truediv(other, builder)
if r is NotImplemented: if r is NotImplemented:
return r return r
else: else:
return r.o_int(builder) return r.o_int(builder)
def or_floordiv(self, other, builder):
r = self.or_truediv(other, builder)
if r is NotImplemented:
return r
else:
return r.o_int(builder)
def or_floordiv(self, other, builder):
r = self.or_truediv(other, builder)
if r is NotImplemented:
return r
else:
return r.o_int(builder)
# Operators # Operators
def _make_unary_operator(op_name): def _make_unary_operator(op_name):
def op(x, builder): def op(x, builder):
try: try:
opf = getattr(x, "o_"+op_name) opf = getattr(x, "o_"+op_name)
except AttributeError: except AttributeError:
raise TypeError("Unsupported operand type for {}: {}".format(op_name, type(x).__name__)) raise TypeError(
return opf(builder) "Unsupported operand type for {}: {}"
return op .format(op_name, type(x).__name__))
return opf(builder)
return op
def _make_binary_operator(op_name): def _make_binary_operator(op_name):
def op(l, r, builder): def op(l, r, builder):
try: try:
opf = getattr(l, "o_"+op_name) opf = getattr(l, "o_"+op_name)
except AttributeError: except AttributeError:
result = NotImplemented result = NotImplemented
else: else:
result = opf(r, builder) result = opf(r, builder)
if result is NotImplemented: if result is NotImplemented:
try: try:
ropf = getattr(r, "or_"+op_name) ropf = getattr(r, "or_"+op_name)
except AttributeError: except AttributeError:
result = NotImplemented result = NotImplemented
else: else:
result = ropf(l, builder) result = ropf(l, builder)
if result is NotImplemented: if result is NotImplemented:
raise TypeError("Unsupported operand types for {}: {} and {}".format( raise TypeError(
op_name, type(l).__name__, type(r).__name__)) "Unsupported operand types for {}: {} and {}"
return result .format(op_name, type(l).__name__, type(r).__name__))
return op return result
return op
def _make_operators(): def _make_operators():
d = dict() d = dict()
for op_name in ("bool", "int", "int64", "round", "round64", "inv", "pos", "neg"): for op_name in ("bool", "int", "int64", "round", "round64",
d[op_name] = _make_unary_operator(op_name) "inv", "pos", "neg"):
d["not_"] = _make_binary_operator("not") d[op_name] = _make_unary_operator(op_name)
for op_name in ("add", "sub", "mul", d["not_"] = _make_binary_operator("not")
"truediv", "floordiv", "mod", for op_name in ("add", "sub", "mul",
"pow", "lshift", "rshift", "xor", "truediv", "floordiv", "mod",
"eq", "ne", "lt", "le", "gt", "ge"): "pow", "lshift", "rshift", "xor",
d[op_name] = _make_binary_operator(op_name) "eq", "ne", "lt", "le", "gt", "ge"):
d["and_"] = _make_binary_operator("and") d[op_name] = _make_binary_operator(op_name)
d["or_"] = _make_binary_operator("or") d["and_"] = _make_binary_operator("and")
return SimpleNamespace(**d) d["or_"] = _make_binary_operator("or")
return SimpleNamespace(**d)
operators = _make_operators() operators = _make_operators()
def init_module(module): def init_module(module):
func_type = lc.Type.function(lc.Type.int(64), func_type = lc.Type.function(
[lc.Type.int(64), lc.Type.int(64)]) lc.Type.int(64), [lc.Type.int(64), lc.Type.int(64)])
module.add_function(func_type, "__gcd64") module.add_function(func_type, "__gcd64")

View File

@ -3,41 +3,48 @@ import ast
from artiq.compiler.tools import value_to_ast from artiq.compiler.tools import value_to_ast
from artiq.language.core import int64 from artiq.language.core import int64
def _insert_int64(node): def _insert_int64(node):
return ast.copy_location( return ast.copy_location(
ast.Call(func=ast.Name("int64", ast.Load()), ast.Call(func=ast.Name("int64", ast.Load()),
args=[node], args=[node],
keywords=[], starargs=[], kwargs=[]), node) keywords=[], starargs=[], kwargs=[]),
node)
class _TimeLowerer(ast.NodeTransformer): class _TimeLowerer(ast.NodeTransformer):
def visit_Call(self, node): def visit_Call(self, node):
if isinstance(node.func, ast.Name) and node.func.id == "now": if isinstance(node.func, ast.Name) and node.func.id == "now":
return ast.copy_location(ast.Name("now", ast.Load()), node) return ast.copy_location(ast.Name("now", ast.Load()), node)
else: else:
self.generic_visit(node) self.generic_visit(node)
return node return node
def visit_Expr(self, node):
self.generic_visit(node)
if (isinstance(node.value, ast.Call)
and isinstance(node.value.func, ast.Name)):
funcname = node.value.func.id
if funcname == "delay":
return ast.copy_location(
ast.AugAssign(target=ast.Name("now", ast.Store()),
op=ast.Add(),
value=_insert_int64(node.value.args[0])),
node)
elif funcname == "at":
return ast.copy_location(
ast.Assign(targets=[ast.Name("now", ast.Store())],
value=_insert_int64(node.value.args[0])),
node)
else:
return node
else:
return node
def visit_Expr(self, node):
self.generic_visit(node)
if isinstance(node.value, ast.Call) and isinstance(node.value.func, ast.Name):
funcname = node.value.func.id
if funcname == "delay":
return ast.copy_location(
ast.AugAssign(target=ast.Name("now", ast.Store()), op=ast.Add(),
value=_insert_int64(node.value.args[0])),
node)
elif funcname == "at":
return ast.copy_location(
ast.Assign(targets=[ast.Name("now", ast.Store())],
value=_insert_int64(node.value.args[0])),
node)
else:
return node
else:
return node
def lower_time(funcdef, initial_time): def lower_time(funcdef, initial_time):
_TimeLowerer().visit(funcdef) _TimeLowerer().visit(funcdef)
funcdef.body.insert(0, ast.copy_location( funcdef.body.insert(0, ast.copy_location(
ast.Assign(targets=[ast.Name("now", ast.Store())], value=value_to_ast(int64(initial_time))), ast.Assign(targets=[ast.Name("now", ast.Store())],
funcdef)) value=value_to_ast(int64(initial_time))),
funcdef))

View File

@ -3,6 +3,7 @@ import ast
from artiq.compiler.tools import value_to_ast from artiq.compiler.tools import value_to_ast
from artiq.language import units from artiq.language import units
# TODO: # TODO:
# * track variable and expression dimensions # * track variable and expression dimensions
# * raise exception on dimension errors in expressions # * raise exception on dimension errors in expressions
@ -11,32 +12,36 @@ from artiq.language import units
# e.g. foo = now() + 1*us [...] at(foo) # e.g. foo = now() + 1*us [...] at(foo)
class _UnitsLowerer(ast.NodeTransformer): class _UnitsLowerer(ast.NodeTransformer):
def __init__(self, ref_period): def __init__(self, ref_period):
self.ref_period = ref_period self.ref_period = ref_period
self.in_core_time = False self.in_core_time = False
def visit_Call(self, node):
fn = node.func.id
if fn in ("delay", "at"):
old_in_core_time = self.in_core_time
self.in_core_time = True
self.generic_visit(node)
self.in_core_time = old_in_core_time
elif fn == "Quantity":
if self.in_core_time:
if node.args[1].id == "microcycle_units":
node = node.args[0]
else:
node = ast.copy_location(
ast.BinOp(left=node.args[0],
op=ast.Div(),
right=value_to_ast(self.ref_period)),
node)
else:
node = node.args[0]
else:
self.generic_visit(node)
return node
def visit_Call(self, node):
fn = node.func.id
if fn in ("delay", "at"):
old_in_core_time = self.in_core_time
self.in_core_time = True
self.generic_visit(node)
self.in_core_time = old_in_core_time
elif fn == "Quantity":
if self.in_core_time:
if node.args[1].id == "microcycle_units":
node = node.args[0]
else:
node = ast.copy_location(
ast.BinOp(left=node.args[0], op=ast.Div(), right=value_to_ast(self.ref_period)),
node)
else:
node = node.args[0]
else:
self.generic_visit(node)
return node
def lower_units(funcdef, ref_period): def lower_units(funcdef, ref_period):
if not isinstance(ref_period, units.Quantity) or ref_period.unit is not units.s_unit: if (not isinstance(ref_period, units.Quantity)
raise units.DimensionError("Reference period not expressed in seconds") or ref_period.unit is not units.s_unit):
_UnitsLowerer(ref_period.amount).visit(funcdef) raise units.DimensionError("Reference period not expressed in seconds")
_UnitsLowerer(ref_period.amount).visit(funcdef)

View File

@ -4,60 +4,67 @@ from fractions import Fraction
from artiq.language import core as core_language from artiq.language import core as core_language
from artiq.language import units from artiq.language import units
def eval_ast(expr, symdict=dict()): def eval_ast(expr, symdict=dict()):
if not isinstance(expr, ast.Expression): if not isinstance(expr, ast.Expression):
expr = ast.copy_location(ast.Expression(expr), expr) expr = ast.copy_location(ast.Expression(expr), expr)
ast.fix_missing_locations(expr) ast.fix_missing_locations(expr)
code = compile(expr, "<ast>", "eval") code = compile(expr, "<ast>", "eval")
return eval(code, symdict) return eval(code, symdict)
def value_to_ast(value): def value_to_ast(value):
if isinstance(value, core_language.int64): # must be before int if isinstance(value, core_language.int64): # must be before int
return ast.Call( return ast.Call(
func=ast.Name("int64", ast.Load()), func=ast.Name("int64", ast.Load()),
args=[ast.Num(int(value))], args=[ast.Num(int(value))],
keywords=[], starargs=None, kwargs=None) keywords=[], starargs=None, kwargs=None)
elif isinstance(value, int): elif isinstance(value, int):
return ast.Num(value) return ast.Num(value)
elif isinstance(value, Fraction): elif isinstance(value, Fraction):
return ast.Call(func=ast.Name("Fraction", ast.Load()), return ast.Call(
args=[ast.Num(value.numerator), ast.Num(value.denominator)], func=ast.Name("Fraction", ast.Load()),
keywords=[], starargs=None, kwargs=None) args=[ast.Num(value.numerator), ast.Num(value.denominator)],
elif isinstance(value, str): keywords=[], starargs=None, kwargs=None)
return ast.Str(value) elif isinstance(value, str):
else: return ast.Str(value)
for kg in core_language.kernel_globals: else:
if value is getattr(core_language, kg): for kg in core_language.kernel_globals:
return ast.Name(kg, ast.Load()) if value is getattr(core_language, kg):
if isinstance(value, units.Quantity): return ast.Name(kg, ast.Load())
return ast.Call( if isinstance(value, units.Quantity):
func=ast.Name("Quantity", ast.Load()), return ast.Call(
args=[value_to_ast(value.amount), ast.Name(value.unit.name+"_unit", ast.Load())], func=ast.Name("Quantity", ast.Load()),
keywords=[], starargs=None, kwargs=None) args=[value_to_ast(value.amount),
return None ast.Name(value.unit.name+"_unit", ast.Load())],
keywords=[], starargs=None, kwargs=None)
return None
class NotConstant(Exception): class NotConstant(Exception):
pass pass
def eval_constant(node): def eval_constant(node):
if isinstance(node, ast.Num): if isinstance(node, ast.Num):
return node.n return node.n
elif isinstance(node, ast.Str): elif isinstance(node, ast.Str):
return node.s return node.s
elif isinstance(node, ast.Call): elif isinstance(node, ast.Call):
funcname = node.func.id funcname = node.func.id
if funcname == "Fraction": if funcname == "Fraction":
numerator, denominator = eval_constant(node.args[0]), eval_constant(node.args[1]) numerator = eval_constant(node.args[0])
return Fraction(numerator, denominator) denominator = eval_constant(node.args[1])
elif funcname == "Quantity": return Fraction(numerator, denominator)
amount, unit = node.args elif funcname == "Quantity":
amount = eval_constant(amount) amount, unit = node.args
try: amount = eval_constant(amount)
unit = getattr(units, unit.id) try:
except: unit = getattr(units, unit.id)
raise NotConstant except:
return units.Quantity(amount, unit) raise NotConstant
else: return units.Quantity(amount, unit)
raise NotConstant else:
else: raise NotConstant
raise NotConstant else:
raise NotConstant

File diff suppressed because it is too large Load Diff

View File

@ -2,46 +2,51 @@ import ast
from artiq.compiler.tools import eval_ast, value_to_ast from artiq.compiler.tools import eval_ast, value_to_ast
def _count_stmts(node): def _count_stmts(node):
if isinstance(node, (ast.For, ast.While, ast.If)): if isinstance(node, (ast.For, ast.While, ast.If)):
return 1 + _count_stmts(node.body) + _count_stmts(node.orelse) return 1 + _count_stmts(node.body) + _count_stmts(node.orelse)
elif isinstance(node, ast.With): elif isinstance(node, ast.With):
return 1 + _count_stmts(node.body) return 1 + _count_stmts(node.body)
elif isinstance(node, list): elif isinstance(node, list):
return sum(map(_count_stmts, node)) return sum(map(_count_stmts, node))
else: else:
return 1 return 1
class _LoopUnroller(ast.NodeTransformer): class _LoopUnroller(ast.NodeTransformer):
def __init__(self, limit): def __init__(self, limit):
self.limit = limit self.limit = limit
def visit_For(self, node):
self.generic_visit(node)
try:
it = eval_ast(node.iter)
except:
return node
l_it = len(it)
if l_it:
n = l_it*_count_stmts(node.body)
if n < self.limit:
replacement = []
for i in it:
if not isinstance(i, int):
replacement = None
break
replacement.append(ast.copy_location(
ast.Assign(targets=[node.target],
value=value_to_ast(i)),
node))
replacement += node.body
if replacement is not None:
return replacement
else:
return node
else:
return node
else:
return node.orelse
def visit_For(self, node):
self.generic_visit(node)
try:
it = eval_ast(node.iter)
except:
return node
l_it = len(it)
if l_it:
n = l_it*_count_stmts(node.body)
if n < self.limit:
replacement = []
for i in it:
if not isinstance(i, int):
replacement = None
break
replacement.append(ast.copy_location(
ast.Assign(targets=[node.target], value=value_to_ast(i)), node))
replacement += node.body
if replacement is not None:
return replacement
else:
return node
else:
return node
else:
return node.orelse
def unroll_loops(node, limit): def unroll_loops(node, limit):
_LoopUnroller(limit).visit(node) _LoopUnroller(limit).visit(node)

View File

@ -6,22 +6,23 @@ from artiq.compiler.interleave import interleave
from artiq.compiler.lower_time import lower_time from artiq.compiler.lower_time import lower_time
from artiq.compiler.ir import get_runtime_binary from artiq.compiler.ir import get_runtime_binary
class Core: class Core:
def __init__(self, core_com, runtime_env=None): def __init__(self, core_com, runtime_env=None):
if runtime_env is None: if runtime_env is None:
runtime_env = core_com.get_runtime_env() runtime_env = core_com.get_runtime_env()
self.runtime_env = runtime_env self.runtime_env = runtime_env
self.core_com = core_com self.core_com = core_com
def run(self, k_function, k_args, k_kwargs): def run(self, k_function, k_args, k_kwargs):
funcdef, rpc_map = inline(self, k_function, k_args, k_kwargs) funcdef, rpc_map = inline(self, k_function, k_args, k_kwargs)
lower_units(funcdef, self.runtime_env.ref_period) lower_units(funcdef, self.runtime_env.ref_period)
fold_constants(funcdef) fold_constants(funcdef)
unroll_loops(funcdef, 50) unroll_loops(funcdef, 50)
interleave(funcdef) interleave(funcdef)
lower_time(funcdef, getattr(self.runtime_env, "initial_time", 0)) lower_time(funcdef, getattr(self.runtime_env, "initial_time", 0))
fold_constants(funcdef) fold_constants(funcdef)
binary = get_runtime_binary(self.runtime_env, funcdef) binary = get_runtime_binary(self.runtime_env, funcdef)
self.core_com.run(binary) self.core_com.run(binary)
self.core_com.serve(rpc_map) self.core_com.serve(rpc_map)

View File

@ -3,26 +3,28 @@ from operator import itemgetter
from artiq.devices.runtime import LinkInterface from artiq.devices.runtime import LinkInterface
from artiq.language.units import ns from artiq.language.units import ns
class _RuntimeEnvironment(LinkInterface):
def __init__(self, ref_period):
self.ref_period = ref_period
def emit_object(self): class _RuntimeEnvironment(LinkInterface):
return str(self.module) def __init__(self, ref_period):
self.ref_period = ref_period
def emit_object(self):
return str(self.module)
class CoreCom: class CoreCom:
def get_runtime_env(self): def get_runtime_env(self):
return _RuntimeEnvironment(10*ns) return _RuntimeEnvironment(10*ns)
def run(self, kcode): def run(self, kcode):
print("================") print("================")
print(" LLVM IR") print(" LLVM IR")
print("================") print("================")
print(kcode) print(kcode)
def serve(self, rpc_map): def serve(self, rpc_map):
print("================") print("================")
print(" RPC map") print(" RPC map")
print("================") print("================")
for k, v in sorted(rpc_map.items(), key=itemgetter(0)): for k, v in sorted(rpc_map.items(), key=itemgetter(0)):
print(str(k)+" -> "+str(v)) print(str(k)+" -> "+str(v))

View File

@ -1,111 +1,126 @@
import os, termios, struct, zlib import os
import termios
import struct
import zlib
from enum import Enum from enum import Enum
from artiq.language import units from artiq.language import units
from artiq.devices.runtime import Environment from artiq.devices.runtime import Environment
class UnsupportedDevice(Exception): class UnsupportedDevice(Exception):
pass pass
class _MsgType(Enum): class _MsgType(Enum):
REQUEST_IDENT = 0x01 REQUEST_IDENT = 0x01
LOAD_KERNEL = 0x02 LOAD_KERNEL = 0x02
KERNEL_FINISHED = 0x03 KERNEL_FINISHED = 0x03
RPC_REQUEST = 0x04 RPC_REQUEST = 0x04
def _write_exactly(f, data): def _write_exactly(f, data):
remaining = len(data) remaining = len(data)
pos = 0 pos = 0
while remaining: while remaining:
written = f.write(data[pos:]) written = f.write(data[pos:])
remaining -= written remaining -= written
pos += written pos += written
def _read_exactly(f, n): def _read_exactly(f, n):
r = bytes() r = bytes()
while(len(r) < n): while(len(r) < n):
r += f.read(n - len(r)) r += f.read(n - len(r))
return r return r
class CoreCom: class CoreCom:
def __init__(self, dev="/dev/ttyUSB1", baud=115200): def __init__(self, dev="/dev/ttyUSB1", baud=115200):
self._fd = os.open(dev, os.O_RDWR | os.O_NOCTTY) self._fd = os.open(dev, os.O_RDWR | os.O_NOCTTY)
self.port = os.fdopen(self._fd, "r+b", buffering=0) self.port = os.fdopen(self._fd, "r+b", buffering=0)
iflag, oflag, cflag, lflag, ispeed, ospeed, cc = \ iflag, oflag, cflag, lflag, ispeed, ospeed, cc = \
termios.tcgetattr(self._fd) termios.tcgetattr(self._fd)
iflag = termios.IGNBRK | termios.IGNPAR iflag = termios.IGNBRK | termios.IGNPAR
oflag = 0 oflag = 0
cflag |= termios.CLOCAL | termios.CREAD | termios.CS8 cflag |= termios.CLOCAL | termios.CREAD | termios.CS8
lflag = 0 lflag = 0
ispeed = ospeed = getattr(termios, "B"+str(baud)) ispeed = ospeed = getattr(termios, "B"+str(baud))
cc[termios.VMIN] = 1 cc[termios.VMIN] = 1
cc[termios.VTIME] = 0 cc[termios.VTIME] = 0
termios.tcsetattr(self._fd, termios.TCSANOW, [ termios.tcsetattr(self._fd, termios.TCSANOW, [
iflag, oflag, cflag, lflag, ispeed, ospeed, cc]) iflag, oflag, cflag, lflag, ispeed, ospeed, cc])
termios.tcdrain(self._fd) termios.tcdrain(self._fd)
termios.tcflush(self._fd, termios.TCOFLUSH) termios.tcflush(self._fd, termios.TCOFLUSH)
termios.tcflush(self._fd, termios.TCIFLUSH) termios.tcflush(self._fd, termios.TCIFLUSH)
def close(self): def close(self):
self.port.close() self.port.close()
def __enter__(self): def __enter__(self):
return self return self
def __exit__(self, type, value, traceback): def __exit__(self, type, value, traceback):
self.close() self.close()
def get_runtime_env(self): def get_runtime_env(self):
_write_exactly(self.port, struct.pack(">lb", 0x5a5a5a5a, _MsgType.REQUEST_IDENT.value)) _write_exactly(self.port, struct.pack(
# FIXME: when loading immediately after a board reset, we erroneously get some zeros back. ">lb", 0x5a5a5a5a, _MsgType.REQUEST_IDENT.value))
# Ignore them with a warning for now. # FIXME: when loading immediately after a board reset,
spurious_zero_count = 0 # we erroneously get some zeros back.
while True: # Ignore them with a warning for now.
(reply, ) = struct.unpack("b", _read_exactly(self.port, 1)) spurious_zero_count = 0
if reply == 0: while True:
spurious_zero_count += 1 (reply, ) = struct.unpack("b", _read_exactly(self.port, 1))
else: if reply == 0:
break spurious_zero_count += 1
if spurious_zero_count: else:
print("Warning: received {} spurious zeros".format(spurious_zero_count)) break
runtime_id = chr(reply) if spurious_zero_count:
for i in range(3): print("Warning: received {} spurious zeros"
(reply, ) = struct.unpack("b", _read_exactly(self.port, 1)) .format(spurious_zero_count))
runtime_id += chr(reply) runtime_id = chr(reply)
if runtime_id != "AROR": for i in range(3):
raise UnsupportedDevice("Unsupported runtime ID: "+runtime_id) (reply, ) = struct.unpack("b", _read_exactly(self.port, 1))
(ref_period, ) = struct.unpack(">l", _read_exactly(self.port, 4)) runtime_id += chr(reply)
return Environment(ref_period*units.ps) if runtime_id != "AROR":
raise UnsupportedDevice("Unsupported runtime ID: "+runtime_id)
(ref_period, ) = struct.unpack(">l", _read_exactly(self.port, 4))
return Environment(ref_period*units.ps)
def run(self, kcode): def run(self, kcode):
_write_exactly(self.port, struct.pack(">lblL", _write_exactly(self.port, struct.pack(
0x5a5a5a5a, _MsgType.LOAD_KERNEL.value, len(kcode), zlib.crc32(kcode))) ">lblL",
_write_exactly(self.port, kcode) 0x5a5a5a5a, _MsgType.LOAD_KERNEL.value,
(reply, ) = struct.unpack("b", _read_exactly(self.port, 1)) len(kcode), zlib.crc32(kcode)))
if reply != 0x4f: _write_exactly(self.port, kcode)
raise IOError("Incorrect reply from device: "+hex(reply)) (reply, ) = struct.unpack("b", _read_exactly(self.port, 1))
if reply != 0x4f:
raise IOError("Incorrect reply from device: "+hex(reply))
def _wait_sync(self): def _wait_sync(self):
recognized = 0 recognized = 0
while recognized < 4: while recognized < 4:
(c, ) = struct.unpack("b", _read_exactly(self.port, 1)) (c, ) = struct.unpack("b", _read_exactly(self.port, 1))
if c == 0x5a: if c == 0x5a:
recognized += 1 recognized += 1
else: else:
recognized = 0 recognized = 0
def serve(self, rpc_map): def serve(self, rpc_map):
while True: while True:
self._wait_sync() self._wait_sync()
msg = _MsgType(*struct.unpack("b", _read_exactly(self.port, 1))) msg = _MsgType(*struct.unpack("b", _read_exactly(self.port, 1)))
if msg == _MsgType.KERNEL_FINISHED: if msg == _MsgType.KERNEL_FINISHED:
return return
elif msg == _MsgType.RPC_REQUEST: elif msg == _MsgType.RPC_REQUEST:
rpc_num, n_args = struct.unpack(">hb", _read_exactly(self.port, 3)) rpc_num, n_args = struct.unpack(">hb",
args = [] _read_exactly(self.port, 3))
for i in range(n_args): args = []
args.append(*struct.unpack(">l", _read_exactly(self.port, 4))) for i in range(n_args):
r = rpc_map[rpc_num](*args) args.append(*struct.unpack(">l",
if r is None: _read_exactly(self.port, 4)))
r = 0 r = rpc_map[rpc_num](*args)
_write_exactly(self.port, struct.pack(">l", r)) if r is None:
r = 0
_write_exactly(self.port, struct.pack(">l", r))

View File

@ -1,20 +1,22 @@
from artiq.language.core import * from artiq.language.core import *
from artiq.language.units import * from artiq.language.units import *
class DDS(AutoContext): class DDS(AutoContext):
parameters = "dds_sysclk reg_channel rtio_channel" parameters = "dds_sysclk reg_channel rtio_channel"
def build(self): def build(self):
self._previous_frequency = 0*MHz self._previous_frequency = 0*MHz
kernel_attr = "_previous_frequency" kernel_attr = "_previous_frequency"
@kernel @kernel
def pulse(self, frequency, duration): def pulse(self, frequency, duration):
if self._previous_frequency != frequency: if self._previous_frequency != frequency:
syscall("rtio_sync", self.rtio_channel) # wait until output is off syscall("rtio_sync", self.rtio_channel) # wait until output is off
syscall("dds_program", self.reg_channel, int(2**32*frequency/self.dds_sysclk)) syscall("dds_program", self.reg_channel,
self._previous_frequency = frequency int(2**32*frequency/self.dds_sysclk))
syscall("rtio_set", now(), self.rtio_channel, 1) self._previous_frequency = frequency
delay(duration) syscall("rtio_set", now(), self.rtio_channel, 1)
syscall("rtio_set", now(), self.rtio_channel, 0) delay(duration)
syscall("rtio_set", now(), self.rtio_channel, 0)

View File

@ -1,8 +1,9 @@
from artiq.language.core import * from artiq.language.core import *
class GPIOOut(AutoContext):
parameters = "channel"
@kernel class GPIOOut(AutoContext):
def set(self, level): parameters = "channel"
syscall("gpio_set", self.channel, level)
@kernel
def set(self, level):
syscall("gpio_set", self.channel, level)

View File

@ -3,70 +3,77 @@ from llvm import target as lt
from artiq.compiler import ir_values from artiq.compiler import ir_values
lt.initialize_all() lt.initialize_all()
_syscalls = { _syscalls = {
"rpc": "i+:i", "rpc": "i+:i",
"gpio_set": "ii:n", "gpio_set": "ii:n",
"rtio_set": "Iii:n", "rtio_set": "Iii:n",
"rtio_sync": "i:n", "rtio_sync": "i:n",
"dds_program": "ii:n", "dds_program": "ii:n",
} }
_chr_to_type = { _chr_to_type = {
"n": lambda: lc.Type.void(), "n": lambda: lc.Type.void(),
"i": lambda: lc.Type.int(32), "i": lambda: lc.Type.int(32),
"I": lambda: lc.Type.int(64) "I": lambda: lc.Type.int(64)
} }
_chr_to_value = { _chr_to_value = {
"n": lambda: ir_values.VNone(), "n": lambda: ir_values.VNone(),
"i": lambda: ir_values.VInt(), "i": lambda: ir_values.VInt(),
"I": lambda: ir_values.VInt(64) "I": lambda: ir_values.VInt(64)
} }
def _str_to_functype(s):
assert(s[-2] == ":")
type_ret = _chr_to_type[s[-1]]()
var_arg_fixcount = None def _str_to_functype(s):
type_args = [] assert(s[-2] == ":")
for n, c in enumerate(s[:-2]): type_ret = _chr_to_type[s[-1]]()
if c == "+":
type_args.append(lc.Type.int()) var_arg_fixcount = None
var_arg_fixcount = n type_args = []
else: for n, c in enumerate(s[:-2]):
type_args.append(_chr_to_type[c]()) if c == "+":
return var_arg_fixcount, lc.Type.function(type_ret, type_args, var_arg=var_arg_fixcount is not None) type_args.append(lc.Type.int())
var_arg_fixcount = n
else:
type_args.append(_chr_to_type[c]())
return (var_arg_fixcount,
lc.Type.function(type_ret, type_args,
var_arg=var_arg_fixcount is not None))
class LinkInterface: class LinkInterface:
def init_module(self, module): def init_module(self, module):
self.module = module self.module = module
self.var_arg_fixcount = dict() self.var_arg_fixcount = dict()
for func_name, func_type_str in _syscalls.items(): for func_name, func_type_str in _syscalls.items():
var_arg_fixcount, func_type = _str_to_functype(func_type_str) var_arg_fixcount, func_type = _str_to_functype(func_type_str)
if var_arg_fixcount is not None: if var_arg_fixcount is not None:
self.var_arg_fixcount[func_name] = var_arg_fixcount self.var_arg_fixcount[func_name] = var_arg_fixcount
self.module.add_function(func_type, "__syscall_"+func_name) self.module.add_function(func_type, "__syscall_"+func_name)
def syscall(self, syscall_name, args, builder):
r = _chr_to_value[_syscalls[syscall_name][-1]]()
if builder is not None:
args = [arg.get_ssa_value(builder) for arg in args]
if syscall_name in self.var_arg_fixcount:
fixcount = self.var_arg_fixcount[syscall_name]
args = args[:fixcount] \
+ [lc.Constant.int(lc.Type.int(), len(args) - fixcount)] \
+ args[fixcount:]
llvm_function = self.module.get_function_named(
"__syscall_" + syscall_name)
r.set_ssa_value(builder, builder.call(llvm_function, args))
return r
def syscall(self, syscall_name, args, builder):
r = _chr_to_value[_syscalls[syscall_name][-1]]()
if builder is not None:
args = [arg.get_ssa_value(builder) for arg in args]
if syscall_name in self.var_arg_fixcount:
fixcount = self.var_arg_fixcount[syscall_name]
args = args[:fixcount] \
+ [lc.Constant.int(lc.Type.int(), len(args) - fixcount)] \
+ args[fixcount:]
llvm_function = self.module.get_function_named("__syscall_"+syscall_name)
r.set_ssa_value(builder, builder.call(llvm_function, args))
return r
class Environment(LinkInterface): class Environment(LinkInterface):
def __init__(self, ref_period): def __init__(self, ref_period):
self.ref_period = ref_period self.ref_period = ref_period
self.initial_time = 2000 self.initial_time = 2000
def emit_object(self): def emit_object(self):
tm = lt.TargetMachine.new(triple="or1k", cpu="generic") tm = lt.TargetMachine.new(triple="or1k", cpu="generic")
return tm.emit_object(self.module) return tm.emit_object(self.module)

View File

@ -1,10 +1,11 @@
from artiq.language.core import * from artiq.language.core import *
class TTLOut(AutoContext):
parameters = "channel"
@kernel class TTLOut(AutoContext):
def pulse(self, duration): parameters = "channel"
syscall("rtio_set", now(), self.channel, 1)
delay(duration) @kernel
syscall("rtio_set", now(), self.channel, 0) def pulse(self, duration):
syscall("rtio_set", now(), self.channel, 1)
delay(duration)
syscall("rtio_set", now(), self.channel, 0)

View File

@ -3,150 +3,169 @@ from fractions import Fraction
from artiq.language import units from artiq.language import units
class int64(int): class int64(int):
pass pass
def _make_int64_op_method(int_method): def _make_int64_op_method(int_method):
def method(self, *args): def method(self, *args):
r = int_method(self, *args) r = int_method(self, *args)
if isinstance(r, int): if isinstance(r, int):
r = int64(r) r = int64(r)
return r return r
return method return method
for _op_name in ( for _op_name in ("neg", "pos", "abs", "invert", "round",
"neg", "pos", "abs", "invert", "round", "add", "radd", "sub", "rsub", "mul", "rmul", "pow", "rpow",
"add", "radd", "sub", "rsub", "mul", "rmul", "pow", "rpow", "lshift", "rlshift", "rshift", "rrshift",
"lshift", "rlshift", "rshift", "rrshift", "and", "rand", "xor", "rxor", "or", "ror",
"and", "rand", "xor", "rxor", "or", "ror", "floordiv", "rfloordiv", "mod", "rmod"):
"floordiv", "rfloordiv", "mod", "rmod"): method_name = "__" + _op_name + "__"
method_name = "__" + _op_name + "__" orig_method = getattr(int, method_name)
orig_method = getattr(int, method_name) setattr(int64, method_name, _make_int64_op_method(orig_method))
setattr(int64, method_name, _make_int64_op_method(orig_method))
for _op_name in ("add", "sub", "mul", "floordiv", "mod",
"pow", "lshift", "rshift", "lshift",
"and", "xor", "or"):
op_method = getattr(int, "__" + _op_name + "__")
setattr(int64, "__i" + _op_name + "__", _make_int64_op_method(op_method))
for _op_name in (
"add", "sub", "mul", "floordiv", "mod",
"pow", "lshift", "rshift", "lshift",
"and", "xor", "or"):
op_method = getattr(int, "__" + _op_name + "__")
setattr(int64, "__i" + _op_name + "__", _make_int64_op_method(op_method))
def round64(x): def round64(x):
return int64(round(x)) return int64(round(x))
def _make_kernel_ro(value): def _make_kernel_ro(value):
return isinstance(value, (bool, int, int64, float, Fraction, units.Quantity)) return isinstance(
value, (bool, int, int64, float, Fraction, units.Quantity))
class AutoContext: class AutoContext:
parameters = "" parameters = ""
implicit_core = True implicit_core = True
def __init__(self, mvs=None, **kwargs): def __init__(self, mvs=None, **kwargs):
kernel_attr_ro = [] kernel_attr_ro = []
self.mvs = mvs self.mvs = mvs
for k, v in kwargs.items(): for k, v in kwargs.items():
setattr(self, k, v) setattr(self, k, v)
if _make_kernel_ro(v): if _make_kernel_ro(v):
kernel_attr_ro.append(k) kernel_attr_ro.append(k)
parameters = self.parameters.split() parameters = self.parameters.split()
if self.implicit_core: if self.implicit_core:
parameters.append("core") parameters.append("core")
for parameter in parameters: for parameter in parameters:
try: try:
value = getattr(self, parameter) value = getattr(self, parameter)
except AttributeError: except AttributeError:
value = self.mvs.get_missing_value(parameter) value = self.mvs.get_missing_value(parameter)
setattr(self, parameter, value) setattr(self, parameter, value)
if _make_kernel_ro(value): if _make_kernel_ro(value):
kernel_attr_ro.append(parameter) kernel_attr_ro.append(parameter)
self.kernel_attr_ro = " ".join(kernel_attr_ro)
self.build() self.kernel_attr_ro = " ".join(kernel_attr_ro)
def get_missing_value(self, parameter): self.build()
try:
return getattr(self, parameter) def get_missing_value(self, parameter):
except AttributeError: try:
return self.mvs.get_missing_value(parameter) return getattr(self, parameter)
except AttributeError:
return self.mvs.get_missing_value(parameter)
def build(self):
""" Overload this function to add sub-experiments"""
pass
def build(self):
""" Overload this function to add sub-experiments"""
pass
KernelFunctionInfo = namedtuple("KernelFunctionInfo", "core_name k_function") KernelFunctionInfo = namedtuple("KernelFunctionInfo", "core_name k_function")
def kernel(arg): def kernel(arg):
if isinstance(arg, str): if isinstance(arg, str):
def real_decorator(k_function): def real_decorator(k_function):
def run_on_core(exp, *k_args, **k_kwargs): def run_on_core(exp, *k_args, **k_kwargs):
getattr(exp, arg).run(k_function, ((exp,) + k_args), k_kwargs) getattr(exp, arg).run(k_function, ((exp,) + k_args), k_kwargs)
run_on_core.k_function_info = KernelFunctionInfo(core_name=arg, k_function=k_function) run_on_core.k_function_info = KernelFunctionInfo(
return run_on_core core_name=arg, k_function=k_function)
return real_decorator return run_on_core
else: return real_decorator
def run_on_core(exp, *k_args, **k_kwargs): else:
exp.core.run(arg, ((exp,) + k_args), k_kwargs) def run_on_core(exp, *k_args, **k_kwargs):
run_on_core.k_function_info = KernelFunctionInfo(core_name="core", k_function=arg) exp.core.run(arg, ((exp,) + k_args), k_kwargs)
return run_on_core run_on_core.k_function_info = KernelFunctionInfo(
core_name="core", k_function=arg)
return run_on_core
class _DummyTimeManager: class _DummyTimeManager:
def _not_implemented(self, *args, **kwargs): def _not_implemented(self, *args, **kwargs):
raise NotImplementedError("Attempted to interpret kernel without a time manager") raise NotImplementedError(
"Attempted to interpret kernel without a time manager")
enter_sequential = _not_implemented enter_sequential = _not_implemented
enter_parallel = _not_implemented enter_parallel = _not_implemented
exit = _not_implemented exit = _not_implemented
take_time = _not_implemented take_time = _not_implemented
get_time = _not_implemented get_time = _not_implemented
set_time = _not_implemented set_time = _not_implemented
_time_manager = _DummyTimeManager() _time_manager = _DummyTimeManager()
def set_time_manager(time_manager): def set_time_manager(time_manager):
global _time_manager global _time_manager
_time_manager = time_manager _time_manager = time_manager
class _DummySyscallManager: class _DummySyscallManager:
def do(self, *args): def do(self, *args):
raise NotImplementedError("Attempted to interpret kernel without a syscall manager") raise NotImplementedError(
"Attempted to interpret kernel without a syscall manager")
_syscall_manager = _DummySyscallManager() _syscall_manager = _DummySyscallManager()
def set_syscall_manager(syscall_manager): def set_syscall_manager(syscall_manager):
global _syscall_manager global _syscall_manager
_syscall_manager = syscall_manager _syscall_manager = syscall_manager
# global namespace for kernels # global namespace for kernels
kernel_globals = "sequential", "parallel", "delay", "now", "at", "syscall" kernel_globals = "sequential", "parallel", "delay", "now", "at", "syscall"
class _Sequential:
def __enter__(self):
_time_manager.enter_sequential()
def __exit__(self, type, value, traceback): class _Sequential:
_time_manager.exit() def __enter__(self):
_time_manager.enter_sequential()
def __exit__(self, type, value, traceback):
_time_manager.exit()
sequential = _Sequential() sequential = _Sequential()
class _Parallel:
def __enter__(self):
_time_manager.enter_parallel()
def __exit__(self, type, value, traceback): class _Parallel:
_time_manager.exit() def __enter__(self):
_time_manager.enter_parallel()
def __exit__(self, type, value, traceback):
_time_manager.exit()
parallel = _Parallel() parallel = _Parallel()
def delay(duration): def delay(duration):
_time_manager.take_time(duration) _time_manager.take_time(duration)
def now(): def now():
return _time_manager.get_time() return _time_manager.get_time()
def at(time): def at(time):
_time_manager.set_time(time) _time_manager.set_time(time)
def syscall(*args): def syscall(*args):
return _syscall_manager.do(*args) return _syscall_manager.do(*args)

View File

@ -1,122 +1,139 @@
from collections import namedtuple from collections import namedtuple
from fractions import Fraction from fractions import Fraction
_prefixes_str = "pnum_kMG" _prefixes_str = "pnum_kMG"
_smallest_prefix = Fraction(1, 10**12) _smallest_prefix = Fraction(1, 10**12)
Unit = namedtuple("Unit", "name") Unit = namedtuple("Unit", "name")
class DimensionError(Exception): class DimensionError(Exception):
pass pass
class Quantity: class Quantity:
def __init__(self, amount, unit): def __init__(self, amount, unit):
self.amount = amount self.amount = amount
self.unit = unit self.unit = unit
def __repr__(self): def __repr__(self):
r_amount = self.amount r_amount = self.amount
if isinstance(r_amount, int) or isinstance(r_amount, Fraction): if isinstance(r_amount, int) or isinstance(r_amount, Fraction):
r_prefix = 0 r_prefix = 0
r_amount = r_amount/_smallest_prefix r_amount = r_amount/_smallest_prefix
if r_amount: if r_amount:
numerator = r_amount.numerator numerator = r_amount.numerator
while numerator % 1000 == 0 and r_prefix < len(_prefixes_str): while numerator % 1000 == 0 and r_prefix < len(_prefixes_str):
numerator /= 1000 numerator /= 1000
r_amount /= 1000 r_amount /= 1000
r_prefix += 1 r_prefix += 1
prefix_str = _prefixes_str[r_prefix] prefix_str = _prefixes_str[r_prefix]
if prefix_str == "_": if prefix_str == "_":
prefix_str = "" prefix_str = ""
return str(r_amount) + " " + prefix_str + self.unit.name return str(r_amount) + " " + prefix_str + self.unit.name
else: else:
return str(r_amount) + " " + self.unit.name return str(r_amount) + " " + self.unit.name
def __mul__(self, other): # mul/div
if isinstance(other, Quantity): def __mul__(self, other):
return NotImplemented if isinstance(other, Quantity):
return Quantity(self.amount*other, self.unit) return NotImplemented
def __rmul__(self, other): return Quantity(self.amount*other, self.unit)
if isinstance(other, Quantity):
return NotImplemented
return Quantity(other*self.amount, self.unit)
def __truediv__(self, other):
if isinstance(other, Quantity):
if other.unit == self.unit:
return self.amount/other.amount
else:
return NotImplemented
else:
return Quantity(self.amount/other, self.unit)
def __floordiv__(self, other):
if isinstance(other, Quantity):
if other.unit == self.unit:
return self.amount//other.amount
else:
return NotImplemented
else:
return Quantity(self.amount//other, self.unit)
def __neg__(self): def __rmul__(self, other):
return Quantity(-self.amount, self.unit) if isinstance(other, Quantity):
return NotImplemented
return Quantity(other*self.amount, self.unit)
def __add__(self, other): def __truediv__(self, other):
if self.unit != other.unit: if isinstance(other, Quantity):
raise DimensionError if other.unit == self.unit:
return Quantity(self.amount + other.amount, self.unit) return self.amount/other.amount
def __radd__(self, other): else:
if self.unit != other.unit: return NotImplemented
raise DimensionError else:
return Quantity(other.amount + self.amount, self.unit) return Quantity(self.amount/other, self.unit)
def __sub__(self, other):
if self.unit != other.unit:
raise DimensionError
return Quantity(self.amount - other.amount, self.unit)
def __rsub__(self, other):
if self.unit != other.unit:
raise DimensionError
return Quantity(other.amount - self.amount, self.unit)
def __lt__(self, other): def __floordiv__(self, other):
if self.unit != other.unit: if isinstance(other, Quantity):
raise DimensionError if other.unit == self.unit:
return self.amount < other.amount return self.amount//other.amount
def __le__(self, other): else:
if self.unit != other.unit: return NotImplemented
raise DimensionError else:
return self.amount <= other.amount return Quantity(self.amount//other, self.unit)
def __eq__(self, other):
if self.unit != other.unit: # unary ops
raise DimensionError def __neg__(self):
return self.amount == other.amount return Quantity(-self.amount, self.unit)
def __ne__(self, other):
if self.unit != other.unit: def __pos__(self):
raise DimensionError return Quantity(self.amount, self.unit)
return self.amount != other.amount
def __gt__(self, other): # add/sub
if self.unit != other.unit: def __add__(self, other):
raise DimensionError if self.unit != other.unit:
return self.amount > other.amount raise DimensionError
def __ge__(self, other): return Quantity(self.amount + other.amount, self.unit)
if self.unit != other.unit:
raise DimensionError def __radd__(self, other):
return self.amount >= other.amount if self.unit != other.unit:
raise DimensionError
return Quantity(other.amount + self.amount, self.unit)
def __sub__(self, other):
if self.unit != other.unit:
raise DimensionError
return Quantity(self.amount - other.amount, self.unit)
def __rsub__(self, other):
if self.unit != other.unit:
raise DimensionError
return Quantity(other.amount - self.amount, self.unit)
# comparisons
def __lt__(self, other):
if self.unit != other.unit:
raise DimensionError
return self.amount < other.amount
def __le__(self, other):
if self.unit != other.unit:
raise DimensionError
return self.amount <= other.amount
def __eq__(self, other):
if self.unit != other.unit:
raise DimensionError
return self.amount == other.amount
def __ne__(self, other):
if self.unit != other.unit:
raise DimensionError
return self.amount != other.amount
def __gt__(self, other):
if self.unit != other.unit:
raise DimensionError
return self.amount > other.amount
def __ge__(self, other):
if self.unit != other.unit:
raise DimensionError
return self.amount >= other.amount
def check_unit(value, unit):
if not isinstance(value, Quantity) or value.unit != unit:
raise DimensionError
return value.amount
def _register_unit(name, prefixes): def _register_unit(name, prefixes):
unit = Unit(name) unit = Unit(name)
globals()[name+"_unit"] = unit globals()[name+"_unit"] = unit
amount = _smallest_prefix amount = _smallest_prefix
for prefix in _prefixes_str: for prefix in _prefixes_str:
if prefix in prefixes: if prefix in prefixes:
quantity = Quantity(amount, unit) quantity = Quantity(amount, unit)
full_name = prefix + name if prefix != "_" else name full_name = prefix + name if prefix != "_" else name
globals()[full_name] = quantity globals()[full_name] = quantity
amount *= 1000 amount *= 1000
_register_unit("s", "pnum_") _register_unit("s", "pnum_")
_register_unit("Hz", "_kMG") _register_unit("Hz", "_kMG")

View File

@ -4,39 +4,43 @@ from artiq.language.core import AutoContext, delay
from artiq.language import units from artiq.language import units
from artiq.sim import time from artiq.sim import time
class Core: class Core:
def run(self, k_function, k_args, k_kwargs): def run(self, k_function, k_args, k_kwargs):
return k_function(*k_args, **k_kwargs) return k_function(*k_args, **k_kwargs)
class Input(AutoContext): class Input(AutoContext):
parameters = "name" parameters = "name"
implicit_core = False implicit_core = False
def build(self): def build(self):
self.prng = Random() self.prng = Random()
def wait_edge(self): def wait_edge(self):
duration = self.prng.randrange(0, 20)*units.ms duration = self.prng.randrange(0, 20)*units.ms
time.manager.event(("wait_edge", self.name, duration)) time.manager.event(("wait_edge", self.name, duration))
delay(duration) delay(duration)
def count_gate(self, duration):
result = self.prng.randrange(0, 100)
time.manager.event(("count_gate", self.name, duration, result))
delay(duration)
return result
def count_gate(self, duration):
result = self.prng.randrange(0, 100)
time.manager.event(("count_gate", self.name, duration, result))
delay(duration)
return result
class WaveOutput(AutoContext): class WaveOutput(AutoContext):
parameters = "name" parameters = "name"
implicit_core = False implicit_core = False
def pulse(self, frequency, duration):
time.manager.event(("pulse", self.name, frequency, duration))
delay(duration)
def pulse(self, frequency, duration):
time.manager.event(("pulse", self.name, frequency, duration))
delay(duration)
class VoltageOutput(AutoContext): class VoltageOutput(AutoContext):
parameters = "name" parameters = "name"
implicit_core = False implicit_core = False
def set(self, value): def set(self, value):
time.manager.event(("set_voltage", self.name, value)) time.manager.event(("set_voltage", self.name, value))

View File

@ -3,66 +3,69 @@ from operator import itemgetter
from artiq.language.units import * from artiq.language.units import *
from artiq.language import core as core_language from artiq.language import core as core_language
class SequentialTimeContext:
def __init__(self, current_time):
self.current_time = current_time
self.block_duration = 0*s
def take_time(self, amount): class SequentialTimeContext:
self.current_time += amount def __init__(self, current_time):
self.block_duration += amount self.current_time = current_time
self.block_duration = 0*s
def take_time(self, amount):
self.current_time += amount
self.block_duration += amount
class ParallelTimeContext: class ParallelTimeContext:
def __init__(self, current_time): def __init__(self, current_time):
self.current_time = current_time self.current_time = current_time
self.block_duration = 0*s self.block_duration = 0*s
def take_time(self, amount):
if amount > self.block_duration:
self.block_duration = amount
def take_time(self, amount):
if amount > self.block_duration:
self.block_duration = amount
class Manager: class Manager:
def __init__(self): def __init__(self):
self.stack = [SequentialTimeContext(0*s)] self.stack = [SequentialTimeContext(0*s)]
self.timeline = [] self.timeline = []
def enter_sequential(self): def enter_sequential(self):
new_context = SequentialTimeContext(self.get_time()) new_context = SequentialTimeContext(self.get_time())
self.stack.append(new_context) self.stack.append(new_context)
def enter_parallel(self): def enter_parallel(self):
new_context = ParallelTimeContext(self.get_time()) new_context = ParallelTimeContext(self.get_time())
self.stack.append(new_context) self.stack.append(new_context)
def exit(self): def exit(self):
old_context = self.stack.pop() old_context = self.stack.pop()
self.take_time(old_context.block_duration) self.take_time(old_context.block_duration)
def take_time(self, duration): def take_time(self, duration):
self.stack[-1].take_time(duration) self.stack[-1].take_time(duration)
def get_time(self): def get_time(self):
return self.stack[-1].current_time return self.stack[-1].current_time
def set_time(self, t): def set_time(self, t):
dt = t - self.get_time() dt = t - self.get_time()
if dt < 0*s: if dt < 0*s:
raise ValueError("Attempted to go back in time") raise ValueError("Attempted to go back in time")
self.take_time(dt) self.take_time(dt)
def event(self, description): def event(self, description):
self.timeline.append((self.get_time(), description)) self.timeline.append((self.get_time(), description))
def format_timeline(self): def format_timeline(self):
r = "" r = ""
prev_time = 0*s prev_time = 0*s
for time, description in sorted(self.timeline, key=itemgetter(0)): for time, description in sorted(self.timeline, key=itemgetter(0)):
r += "@{:10} (+{:10}) ".format(str(time), str(time-prev_time)) r += "@{:10} (+{:10}) ".format(str(time), str(time-prev_time))
for item in description: for item in description:
r += "{:16}".format(str(item)) r += "{:16}".format(str(item))
r += "\n" r += "\n"
prev_time = time prev_time = time
return r return r
manager = Manager() manager = Manager()
core_language.set_time_manager(manager) core_language.set_time_manager(manager)

View File

@ -1,51 +1,54 @@
from artiq.language.units import * from artiq.language.units import *
from artiq.language.core import * from artiq.language.core import *
class AluminumSpectroscopy(AutoContext):
parameters = "mains_sync laser_cooling spectroscopy spectroscopy_b state_detection pmt \
spectroscopy_freq photon_limit_low photon_limit_high"
@kernel class AluminumSpectroscopy(AutoContext):
def run(self): parameters = "mains_sync laser_cooling spectroscopy spectroscopy_b state_detection pmt \
state_0_count = 0 spectroscopy_freq photon_limit_low photon_limit_high"
for count in range(100):
self.mains_sync.wait_edge() @kernel
delay(10*us) def run(self):
self.laser_cooling.pulse(100*MHz, 100*us) state_0_count = 0
delay(5*us) for count in range(100):
with parallel: self.mains_sync.wait_edge()
self.spectroscopy.pulse(self.spectroscopy_freq, 100*us) delay(10*us)
with sequential: self.laser_cooling.pulse(100*MHz, 100*us)
delay(50*us) delay(5*us)
self.spectroscopy_b.set(200) with parallel:
delay(5*us) self.spectroscopy.pulse(self.spectroscopy_freq, 100*us)
while True: with sequential:
delay(5*us) delay(50*us)
with parallel: self.spectroscopy_b.set(200)
self.state_detection.pulse(100*MHz, 10*us) delay(5*us)
photon_count = self.pmt.count_gate(10*us) while True:
if photon_count < self.photon_limit_low or photon_count > self.photon_limit_high: delay(5*us)
break with parallel:
if photon_count < self.photon_limit_low: self.state_detection.pulse(100*MHz, 10*us)
state_0_count += 1 photon_count = self.pmt.count_gate(10*us)
return state_0_count if (photon_count < self.photon_limit_low
or photon_count > self.photon_limit_high):
break
if photon_count < self.photon_limit_low:
state_0_count += 1
return state_0_count
if __name__ == "__main__": if __name__ == "__main__":
from artiq.sim import devices as sd from artiq.sim import devices as sd
from artiq.sim import time from artiq.sim import time
exp = AluminumSpectroscopy( exp = AluminumSpectroscopy(
core=sd.Core(), core=sd.Core(),
mains_sync=sd.Input(name="mains_sync"), mains_sync=sd.Input(name="mains_sync"),
laser_cooling=sd.WaveOutput(name="laser_cooling"), laser_cooling=sd.WaveOutput(name="laser_cooling"),
spectroscopy=sd.WaveOutput(name="spectroscopy"), spectroscopy=sd.WaveOutput(name="spectroscopy"),
spectroscopy_b=sd.VoltageOutput(name="spectroscopy_b"), spectroscopy_b=sd.VoltageOutput(name="spectroscopy_b"),
state_detection=sd.WaveOutput(name="state_detection"), state_detection=sd.WaveOutput(name="state_detection"),
pmt=sd.Input(name="pmt"), pmt=sd.Input(name="pmt"),
spectroscopy_freq=432*MHz, spectroscopy_freq=432*MHz,
photon_limit_low=10, photon_limit_low=10,
photon_limit_high=15 photon_limit_high=15
) )
exp.run() exp.run()
print(time.manager.format_timeline()) print(time.manager.format_timeline())

View File

@ -1,41 +1,48 @@
from artiq.language.units import * from artiq.language.units import *
from artiq.language.core import * from artiq.language.core import *
my_range = range my_range = range
class CompilerTest(AutoContext): class CompilerTest(AutoContext):
parameters = "a b A B" parameters = "a b A B"
def print_done(self): def print_done(self):
print("Done!") print("Done!")
def set_some_slowdev(self, n): def set_some_slowdev(self, n):
print("Slow device setting: {}".format(n)) print("Slow device setting: {}".format(n))
@kernel
def run(self, n, t2):
for i in my_range(n):
self.set_some_slowdev(i)
delay(100*ms)
with parallel:
with sequential:
for j in my_range(3):
self.a.pulse((j+1)*100*MHz, 20*us)
self.b.pulse(100*MHz, t2)
with sequential:
self.A.pulse(100*MHz, 10*us)
self.B.pulse(100*MHz, t2)
self.print_done()
@kernel
def run(self, n, t2):
for i in my_range(n):
self.set_some_slowdev(i)
delay(100*ms)
with parallel:
with sequential:
for j in my_range(3):
self.a.pulse((j+1)*100*MHz, 20*us)
self.b.pulse(100*MHz, t2)
with sequential:
self.A.pulse(100*MHz, 10*us)
self.B.pulse(100*MHz, t2)
self.print_done()
if __name__ == "__main__": if __name__ == "__main__":
from artiq.devices import corecom_dummy, core, dds_core from artiq.devices import corecom_dummy, core, dds_core
coredev = core.Core(corecom_dummy.CoreCom()) coredev = core.Core(corecom_dummy.CoreCom())
exp = CompilerTest( exp = CompilerTest(
core=coredev, core=coredev,
a=dds_core.DDS(core=coredev, dds_sysclk=1*GHz, reg_channel=0, rtio_channel=0), a=dds_core.DDS(core=coredev, dds_sysclk=1*GHz,
b=dds_core.DDS(core=coredev, dds_sysclk=1*GHz, reg_channel=1, rtio_channel=1), reg_channel=0, rtio_channel=0),
A=dds_core.DDS(core=coredev, dds_sysclk=1*GHz, reg_channel=2, rtio_channel=2), b=dds_core.DDS(core=coredev, dds_sysclk=1*GHz,
B=dds_core.DDS(core=coredev, dds_sysclk=1*GHz, reg_channel=3, rtio_channel=3) reg_channel=1, rtio_channel=1),
) A=dds_core.DDS(core=coredev, dds_sysclk=1*GHz,
exp.run(3, 100*us) reg_channel=2, rtio_channel=2),
B=dds_core.DDS(core=coredev, dds_sysclk=1*GHz,
reg_channel=3, rtio_channel=3)
)
exp.run(3, 100*us)

View File

@ -1,37 +1,39 @@
from artiq.language.core import AutoContext, kernel from artiq.language.core import AutoContext, kernel
from artiq.devices import corecom_serial, core, gpio_core from artiq.devices import corecom_serial, core, gpio_core
class CompilerTest(AutoContext): class CompilerTest(AutoContext):
parameters = "led" parameters = "led"
def output(self, n): def output(self, n):
print("Received: "+str(n)) print("Received: "+str(n))
def get_max(self): def get_max(self):
return int(input("Maximum: ")) return int(input("Maximum: "))
@kernel
def run(self):
self.led.set(1)
x = 1
m = self.get_max()
while x < m:
d = 2
prime = True
while d*d <= x:
if x % d == 0:
prime = False
d += 1
if prime:
self.output(x)
x += 1
self.led.set(0)
@kernel
def run(self):
self.led.set(1)
x = 1
m = self.get_max()
while x < m:
d = 2
prime = True
while d*d <= x:
if x % d == 0:
prime = False
d += 1
if prime:
self.output(x)
x += 1
self.led.set(0)
if __name__ == "__main__": if __name__ == "__main__":
with corecom_serial.CoreCom() as com: with corecom_serial.CoreCom() as com:
coredev = core.Core(com) coredev = core.Core(com)
exp = CompilerTest( exp = CompilerTest(
core=coredev, core=coredev,
led=gpio_core.GPIOOut(core=coredev, channel=0) led=gpio_core.GPIOOut(core=coredev, channel=0)
) )
exp.run() exp.run()

View File

@ -2,36 +2,42 @@ from artiq.language.units import *
from artiq.language.core import * from artiq.language.core import *
from artiq.devices import corecom_serial, core, dds_core, gpio_core from artiq.devices import corecom_serial, core, dds_core, gpio_core
class DDSTest(AutoContext):
parameters = "a b c d led"
@kernel class DDSTest(AutoContext):
def run(self): parameters = "a b c d led"
i = 0
while i < 10000: @kernel
if i & 0x200: def run(self):
self.led.set(1) i = 0
else: while i < 10000:
self.led.set(0) if i & 0x200:
with parallel: self.led.set(1)
with sequential: else:
self.a.pulse(100*MHz + 4*i*kHz, 500*us) self.led.set(0)
self.b.pulse(120*MHz, 500*us) with parallel:
with sequential: with sequential:
self.c.pulse(200*MHz, 100*us) self.a.pulse(100*MHz + 4*i*kHz, 500*us)
self.d.pulse(250*MHz, 200*us) self.b.pulse(120*MHz, 500*us)
i += 1 with sequential:
self.led.set(0) self.c.pulse(200*MHz, 100*us)
self.d.pulse(250*MHz, 200*us)
i += 1
self.led.set(0)
if __name__ == "__main__": if __name__ == "__main__":
with corecom_serial.CoreCom() as com: with corecom_serial.CoreCom() as com:
coredev = core.Core(com) coredev = core.Core(com)
exp = DDSTest( exp = DDSTest(
core=coredev, core=coredev,
a=dds_core.DDS(core=coredev, dds_sysclk=1*GHz, reg_channel=0, rtio_channel=0), a=dds_core.DDS(core=coredev, dds_sysclk=1*GHz,
b=dds_core.DDS(core=coredev, dds_sysclk=1*GHz, reg_channel=1, rtio_channel=1), reg_channel=0, rtio_channel=0),
c=dds_core.DDS(core=coredev, dds_sysclk=1*GHz, reg_channel=2, rtio_channel=2), b=dds_core.DDS(core=coredev, dds_sysclk=1*GHz,
d=dds_core.DDS(core=coredev, dds_sysclk=1*GHz, reg_channel=3, rtio_channel=3), reg_channel=1, rtio_channel=1),
led=gpio_core.GPIOOut(core=coredev, channel=1) c=dds_core.DDS(core=coredev, dds_sysclk=1*GHz,
) reg_channel=2, rtio_channel=2),
exp.run() d=dds_core.DDS(core=coredev, dds_sysclk=1*GHz,
reg_channel=3, rtio_channel=3),
led=gpio_core.GPIOOut(core=coredev, channel=1)
)
exp.run()

View File

@ -1,29 +1,31 @@
from artiq.language.units import * from artiq.language.units import *
from artiq.language.core import * from artiq.language.core import *
class SimpleSimulation(AutoContext):
parameters = "a b c d"
@kernel class SimpleSimulation(AutoContext):
def run(self): parameters = "a b c d"
with parallel:
with sequential: @kernel
self.a.pulse(100*MHz, 20*us) def run(self):
self.b.pulse(200*MHz, 20*us) with parallel:
with sequential: with sequential:
self.c.pulse(300*MHz, 10*us) self.a.pulse(100*MHz, 20*us)
self.d.pulse(400*MHz, 20*us) self.b.pulse(200*MHz, 20*us)
with sequential:
self.c.pulse(300*MHz, 10*us)
self.d.pulse(400*MHz, 20*us)
if __name__ == "__main__": if __name__ == "__main__":
from artiq.sim import devices as sd from artiq.sim import devices as sd
from artiq.sim import time from artiq.sim import time
exp = SimpleSimulation( exp = SimpleSimulation(
core=sd.Core(), core=sd.Core(),
a=sd.WaveOutput(name="a"), a=sd.WaveOutput(name="a"),
b=sd.WaveOutput(name="b"), b=sd.WaveOutput(name="b"),
c=sd.WaveOutput(name="c"), c=sd.WaveOutput(name="c"),
d=sd.WaveOutput(name="d"), d=sd.WaveOutput(name="d"),
) )
exp.run() exp.run()
print(time.manager.format_timeline()) print(time.manager.format_timeline())

View File

@ -2,45 +2,48 @@ from artiq.language.units import *
from artiq.language.core import * from artiq.language.core import *
from artiq.devices import corecom_serial, core from artiq.devices import corecom_serial, core
class DummyPulse(AutoContext): class DummyPulse(AutoContext):
parameters = "name" parameters = "name"
def print_on(self, t, f): def print_on(self, t, f):
print("{} ON:{:4} @{}".format(self.name, f, t)) print("{} ON:{:4} @{}".format(self.name, f, t))
def print_off(self, t): def print_off(self, t):
print("{} OFF @{}".format(self.name, t)) print("{} OFF @{}".format(self.name, t))
@kernel
def pulse(self, f, duration):
self.print_on(int(now()), f)
delay(duration)
self.print_off(int(now()))
@kernel
def pulse(self, f, duration):
self.print_on(int(now()), f)
delay(duration)
self.print_off(int(now()))
class TimeTest(AutoContext): class TimeTest(AutoContext):
parameters = "a b c d" parameters = "a b c d"
@kernel
def run(self):
i = 0
while i < 3:
with parallel:
with sequential:
self.a.pulse(100+i, 20*us)
self.b.pulse(200+i, 20*us)
with sequential:
self.c.pulse(300+i, 10*us)
self.d.pulse(400+i, 20*us)
i += 1
@kernel
def run(self):
i = 0
while i < 3:
with parallel:
with sequential:
self.a.pulse(100+i, 20*us)
self.b.pulse(200+i, 20*us)
with sequential:
self.c.pulse(300+i, 10*us)
self.d.pulse(400+i, 20*us)
i += 1
if __name__ == "__main__": if __name__ == "__main__":
with corecom_serial.CoreCom() as com: with corecom_serial.CoreCom() as com:
coredev = core.Core(com) coredev = core.Core(com)
exp = TimeTest( exp = TimeTest(
core=coredev, core=coredev,
a=DummyPulse(core=coredev, name="a"), a=DummyPulse(core=coredev, name="a"),
b=DummyPulse(core=coredev, name="b"), b=DummyPulse(core=coredev, name="b"),
c=DummyPulse(core=coredev, name="c"), c=DummyPulse(core=coredev, name="c"),
d=DummyPulse(core=coredev, name="d"), d=DummyPulse(core=coredev, name="d"),
) )
exp.run() exp.run()

View File

@ -4,192 +4,197 @@ from migen.bus import wishbone
from migen.bus.transactions import * from migen.bus.transactions import *
from migen.sim.generic import run_simulation from migen.sim.generic import run_simulation
class AD9858(Module): class AD9858(Module):
"""Wishbone interface to the AD9858 DDS chip. """Wishbone interface to the AD9858 DDS chip.
Addresses 0-63 map the AD9858 registers. Addresses 0-63 map the AD9858 registers.
Data is zero-padded. Data is zero-padded.
Write to address 64 to pulse the FUD signal. Write to address 64 to pulse the FUD signal.
Address 65 is a GPIO register that controls the sel, p and reset signals. Address 65 is a GPIO register that controls the sel, p and reset signals.
sel is mapped to the lower bits, followed by p and reset. sel is mapped to the lower bits, followed by p and reset.
Write timing: Write timing:
Address is set one cycle before assertion of we_n. Address is set one cycle before assertion of we_n.
we_n is asserted for one cycle, at the same time as valid data is driven. we_n is asserted for one cycle, at the same time as valid data is driven.
Read timing: Read timing:
Address is set one cycle before assertion of rd_n. Address is set one cycle before assertion of rd_n.
rd_n is asserted for 3 cycles. rd_n is asserted for 3 cycles.
Data is sampled 2 cycles into the assertion of rd_n. Data is sampled 2 cycles into the assertion of rd_n.
Design: Design:
All IO pads are registered. All IO pads are registered.
LVDS driver/receiver propagation delays are 3.6+4.5 ns max LVDS driver/receiver propagation delays are 3.6+4.5 ns max
LVDS state transition delays are 20, 15 ns max LVDS state transition delays are 20, 15 ns max
Schmitt trigger delays are 6.4ns max Schmitt trigger delays are 6.4ns max
Round-trip addr A setup (> RX, RD, D to Z), RD prop, D valid (< D Round-trip addr A setup (> RX, RD, D to Z), RD prop, D valid (< D
valid), D prop is ~15 + 10 + 20 + 10 = 55ns valid), D prop is ~15 + 10 + 20 + 10 = 55ns
""" """
def __init__(self, pads, bus=None): def __init__(self, pads, bus=None):
if bus is None: if bus is None:
bus = wishbone.Interface() bus = wishbone.Interface()
self.bus = bus self.bus = bus
### # # #
dts = TSTriple(8) dts = TSTriple(8)
self.specials += dts.get_tristate(pads.d) self.specials += dts.get_tristate(pads.d)
dr = Signal(8) dr = Signal(8)
rx = Signal() rx = Signal()
self.sync += [ self.sync += [
pads.a.eq(bus.adr), pads.a.eq(bus.adr),
dts.o.eq(bus.dat_w), dts.o.eq(bus.dat_w),
dr.eq(dts.i), dr.eq(dts.i),
dts.oe.eq(~rx) dts.oe.eq(~rx)
] ]
gpio = Signal(flen(pads.sel) + flen(pads.p) + 1) gpio = Signal(flen(pads.sel) + flen(pads.p) + 1)
gpio_load = Signal() gpio_load = Signal()
self.sync += If(gpio_load, gpio.eq(bus.dat_w)) self.sync += If(gpio_load, gpio.eq(bus.dat_w))
self.comb += [ self.comb += [
Cat(pads.sel, pads.p).eq(gpio), Cat(pads.sel, pads.p).eq(gpio),
pads.rst_n.eq(~gpio[-1]), pads.rst_n.eq(~gpio[-1]),
] ]
bus_r_gpio = Signal() bus_r_gpio = Signal()
self.comb += If(bus_r_gpio, self.comb += If(bus_r_gpio,
bus.dat_r.eq(gpio) bus.dat_r.eq(gpio)
).Else( ).Else(
bus.dat_r.eq(dr) bus.dat_r.eq(dr)
) )
fud = Signal() fud = Signal()
self.sync += pads.fud_n.eq(~fud) self.sync += pads.fud_n.eq(~fud)
pads.wr_n.reset = 1 pads.wr_n.reset = 1
pads.rd_n.reset = 1 pads.rd_n.reset = 1
wr = Signal() wr = Signal()
rd = Signal() rd = Signal()
self.sync += pads.wr_n.eq(~wr), pads.rd_n.eq(~rd) self.sync += pads.wr_n.eq(~wr), pads.rd_n.eq(~rd)
fsm = FSM("IDLE") fsm = FSM("IDLE")
self.submodules += fsm self.submodules += fsm
fsm.act("IDLE",
If(bus.cyc & bus.stb,
If(bus.adr[6],
If(bus.adr[0],
NextState("GPIO")
).Else(
NextState("FUD")
)
).Else(
If(bus.we,
NextState("WRITE")
).Else(
NextState("READ")
)
)
)
)
fsm.act("WRITE",
# 3ns A setup to WR active
wr.eq(1),
NextState("WRITE0")
)
fsm.act("WRITE0",
# 3.5ns D setup to WR inactive
# 0ns D and A hold to WR inactive
bus.ack.eq(1),
NextState("IDLE")
)
fsm.act("READ",
# 15ns D valid to A setup
# 15ns D valid to RD active
rx.eq(1),
rd.eq(1),
NextState("READ0")
)
fsm.act("READ0",
rx.eq(1),
rd.eq(1),
NextState("READ1")
)
fsm.act("READ1",
rx.eq(1),
rd.eq(1),
NextState("READ2")
)
fsm.act("READ2",
rx.eq(1),
rd.eq(1),
NextState("READ3")
)
fsm.act("READ3",
rx.eq(1),
rd.eq(1),
NextState("READ4")
)
fsm.act("READ4",
rx.eq(1),
NextState("READ5")
)
fsm.act("READ5",
# 5ns D three-state to RD inactive
# 10ns A hold to RD inactive
rx.eq(1),
bus.ack.eq(1),
NextState("IDLE")
)
fsm.act("GPIO",
bus.ack.eq(1),
bus_r_gpio.eq(1),
If(bus.we, gpio_load.eq(1)),
NextState("IDLE")
)
fsm.act("FUD",
# 4ns FUD setup to SYNCLK
# 0ns FUD hold to SYNCLK
fud.eq(1),
bus.ack.eq(1),
NextState("IDLE")
)
fsm.act("IDLE",
If(bus.cyc & bus.stb,
If(bus.adr[6],
If(bus.adr[0],
NextState("GPIO")
).Else(
NextState("FUD")
)
).Else(
If(bus.we,
NextState("WRITE")
).Else(
NextState("READ")
)
)
)
)
fsm.act("WRITE",
# 3ns A setup to WR active
wr.eq(1),
NextState("WRITE0")
)
fsm.act("WRITE0",
# 3.5ns D setup to WR inactive
# 0ns D and A hold to WR inactive
bus.ack.eq(1),
NextState("IDLE")
)
fsm.act("READ",
# 15ns D valid to A setup
# 15ns D valid to RD active
rx.eq(1),
rd.eq(1),
NextState("READ0")
)
fsm.act("READ0",
rx.eq(1),
rd.eq(1),
NextState("READ1")
)
fsm.act("READ1",
rx.eq(1),
rd.eq(1),
NextState("READ2")
)
fsm.act("READ2",
rx.eq(1),
rd.eq(1),
NextState("READ3")
)
fsm.act("READ3",
rx.eq(1),
rd.eq(1),
NextState("READ4")
)
fsm.act("READ4",
rx.eq(1),
NextState("READ5")
)
fsm.act("READ5",
# 5ns D three-state to RD inactive
# 10ns A hold to RD inactive
rx.eq(1),
bus.ack.eq(1),
NextState("IDLE")
)
fsm.act("GPIO",
bus.ack.eq(1),
bus_r_gpio.eq(1),
If(bus.we, gpio_load.eq(1)),
NextState("IDLE")
)
fsm.act("FUD",
# 4ns FUD setup to SYNCLK
# 0ns FUD hold to SYNCLK
fud.eq(1),
bus.ack.eq(1),
NextState("IDLE")
)
def _test_gen(): def _test_gen():
# Test external bus writes # Test external bus writes
yield TWrite(4, 2) yield TWrite(4, 2)
yield TWrite(5, 3) yield TWrite(5, 3)
yield yield
# Test external bus reads # Test external bus reads
yield TRead(14) yield TRead(14)
yield TRead(15) yield TRead(15)
yield yield
# Test FUD # Test FUD
yield TWrite(64, 0) yield TWrite(64, 0)
yield yield
# Test GPIO # Test GPIO
yield TWrite(65, 0xff) yield TWrite(65, 0xff)
yield yield
class _TestPads: class _TestPads:
def __init__(self): def __init__(self):
self.a = Signal(6) self.a = Signal(6)
self.d = Signal(8) self.d = Signal(8)
self.sel = Signal(5) self.sel = Signal(5)
self.p = Signal(2) self.p = Signal(2)
self.fud_n = Signal() self.fud_n = Signal()
self.wr_n = Signal() self.wr_n = Signal()
self.rd_n = Signal() self.rd_n = Signal()
self.rst_n = Signal() self.rst_n = Signal()
class _TB(Module): class _TB(Module):
def __init__(self): def __init__(self):
pads = _TestPads() pads = _TestPads()
self.submodules.dut = AD9858(pads) self.submodules.dut = AD9858(pads)
self.submodules.initiator = wishbone.Initiator(_test_gen()) self.submodules.initiator = wishbone.Initiator(_test_gen())
self.submodules.interconnect = wishbone.InterconnectPointToPoint(self.initiator.bus, self.dut.bus) self.submodules.interconnect = wishbone.InterconnectPointToPoint(self.initiator.bus, self.dut.bus)
if __name__ == "__main__": if __name__ == "__main__":
run_simulation(_TB(), vcd_name="ad9858.vcd") run_simulation(_TB(), vcd_name="ad9858.vcd")

View File

@ -5,179 +5,182 @@ from migen.genlib.cdc import MultiReg
from artiqlib.rtio.rbus import get_fine_ts_width from artiqlib.rtio.rbus import get_fine_ts_width
class _RTIOBankO(Module): class _RTIOBankO(Module):
def __init__(self, rbus, counter_width, fine_ts_width, fifo_depth, counter_init): def __init__(self, rbus, counter_width, fine_ts_width, fifo_depth, counter_init):
self.sel = Signal(max=len(rbus)) self.sel = Signal(max=len(rbus))
self.timestamp = Signal(counter_width+fine_ts_width) self.timestamp = Signal(counter_width+fine_ts_width)
self.value = Signal(2) self.value = Signal(2)
self.writable = Signal() self.writable = Signal()
self.we = Signal() self.we = Signal()
self.underflow = Signal() self.underflow = Signal()
self.level = Signal(bits_for(fifo_depth)) self.level = Signal(bits_for(fifo_depth))
### # # #
counter = Signal(counter_width, reset=counter_init) counter = Signal(counter_width, reset=counter_init)
self.sync += [ self.sync += [
counter.eq(counter + 1), counter.eq(counter + 1),
If(self.we & self.writable, If(self.we & self.writable,
If(self.timestamp[fine_ts_width:] < counter + 2, self.underflow.eq(1)) If(self.timestamp[fine_ts_width:] < counter + 2, self.underflow.eq(1))
) )
] ]
fifos = [] fifos = []
for n, chif in enumerate(rbus): for n, chif in enumerate(rbus):
fifo = SyncFIFOBuffered([ fifo = SyncFIFOBuffered([
("timestamp", counter_width+fine_ts_width), ("value", 2)], ("timestamp", counter_width+fine_ts_width), ("value", 2)],
fifo_depth) fifo_depth)
self.submodules += fifo self.submodules += fifo
fifos.append(fifo) fifos.append(fifo)
# FIFO write # FIFO write
self.comb += [ self.comb += [
fifo.din.timestamp.eq(self.timestamp), fifo.din.timestamp.eq(self.timestamp),
fifo.din.value.eq(self.value), fifo.din.value.eq(self.value),
fifo.we.eq(self.we & (self.sel == n)) fifo.we.eq(self.we & (self.sel == n))
] ]
# FIFO read # FIFO read
self.comb += [ self.comb += [
chif.o_stb.eq(fifo.readable & chif.o_stb.eq(fifo.readable &
(fifo.dout.timestamp[fine_ts_width:] == counter)), (fifo.dout.timestamp[fine_ts_width:] == counter)),
chif.o_value.eq(fifo.dout.value), chif.o_value.eq(fifo.dout.value),
fifo.re.eq(chif.o_stb) fifo.re.eq(chif.o_stb)
] ]
if fine_ts_width: if fine_ts_width:
self.comb += chif.o_fine_ts.eq(fifo.dout.timestamp[:fine_ts_width]) self.comb += chif.o_fine_ts.eq(fifo.dout.timestamp[:fine_ts_width])
selfifo = Array(fifos)[self.sel]
self.comb += self.writable.eq(selfifo.writable), self.level.eq(selfifo.level)
selfifo = Array(fifos)[self.sel]
self.comb += self.writable.eq(selfifo.writable), self.level.eq(selfifo.level)
class _RTIOBankI(Module): class _RTIOBankI(Module):
def __init__(self, rbus, counter_width, fine_ts_width, fifo_depth): def __init__(self, rbus, counter_width, fine_ts_width, fifo_depth):
self.sel = Signal(max=len(rbus)) self.sel = Signal(max=len(rbus))
self.timestamp = Signal(counter_width+fine_ts_width) self.timestamp = Signal(counter_width+fine_ts_width)
self.value = Signal() self.value = Signal()
self.readable = Signal() self.readable = Signal()
self.re = Signal() self.re = Signal()
self.overflow = Signal() self.overflow = Signal()
### ###
counter = Signal(counter_width) counter = Signal(counter_width)
self.sync += counter.eq(counter + 1) self.sync += counter.eq(counter + 1)
timestamps = [] timestamps = []
values = [] values = []
readables = [] readables = []
overflows = [] overflows = []
for n, chif in enumerate(rbus): for n, chif in enumerate(rbus):
if hasattr(chif, "oe"): if hasattr(chif, "oe"):
sensitivity = Signal(2) sensitivity = Signal(2)
self.sync += If(~chif.oe & chif.o_stb, self.sync += If(~chif.oe & chif.o_stb,
sensitivity.eq(chif.o_value)) sensitivity.eq(chif.o_value))
fifo = SyncFIFOBuffered([ fifo = SyncFIFOBuffered([
("timestamp", counter_width+fine_ts_width), ("value", 1)], ("timestamp", counter_width+fine_ts_width), ("value", 1)],
fifo_depth) fifo_depth)
self.submodules += fifo self.submodules += fifo
# FIFO write # FIFO write
if fine_ts_width: if fine_ts_width:
full_ts = Cat(chif.i_fine_ts, counter) full_ts = Cat(chif.i_fine_ts, counter)
else: else:
full_ts = counter full_ts = counter
self.comb += [ self.comb += [
fifo.din.timestamp.eq(full_ts), fifo.din.timestamp.eq(full_ts),
fifo.din.value.eq(chif.i_value), fifo.din.value.eq(chif.i_value),
fifo.we.eq(~chif.oe & chif.i_stb & fifo.we.eq(~chif.oe & chif.i_stb &
((chif.i_value & sensitivity[0]) | (~chif.i_value & sensitivity[1]))) ((chif.i_value & sensitivity[0]) | (~chif.i_value & sensitivity[1])))
] ]
# FIFO read # FIFO read
timestamps.append(fifo.dout.timestamp) timestamps.append(fifo.dout.timestamp)
values.append(fifo.dout.value) values.append(fifo.dout.value)
readables.append(fifo.readable) readables.append(fifo.readable)
self.comb += fifo.re.eq(self.re & (self.sel == n)) self.comb += fifo.re.eq(self.re & (self.sel == n))
overflow = Signal() overflow = Signal()
self.sync += If(fifo.we & ~fifo.writable, overflow.eq(1)) self.sync += If(fifo.we & ~fifo.writable, overflow.eq(1))
overflows.append(overflow) overflows.append(overflow)
else: else:
timestamps.append(0) timestamps.append(0)
values.append(0) values.append(0)
readables.append(0) readables.append(0)
overflows.append(0) overflows.append(0)
self.comb += [
self.timestamp.eq(Array(timestamps)[self.sel]),
self.value.eq(Array(values)[self.sel]),
self.readable.eq(Array(readables)[self.sel]),
self.overflow.eq(Array(overflows)[self.sel])
]
self.comb += [
self.timestamp.eq(Array(timestamps)[self.sel]),
self.value.eq(Array(values)[self.sel]),
self.readable.eq(Array(readables)[self.sel]),
self.overflow.eq(Array(overflows)[self.sel])
]
class RTIO(Module, AutoCSR): class RTIO(Module, AutoCSR):
def __init__(self, phy, counter_width=32, ofifo_depth=8, ififo_depth=8): def __init__(self, phy, counter_width=32, ofifo_depth=8, ififo_depth=8):
fine_ts_width = get_fine_ts_width(phy.rbus) fine_ts_width = get_fine_ts_width(phy.rbus)
# Submodules # Submodules
self.submodules.bank_o = InsertReset(_RTIOBankO(phy.rbus, self.submodules.bank_o = InsertReset(_RTIOBankO(phy.rbus,
counter_width, fine_ts_width, ofifo_depth, counter_width, fine_ts_width, ofifo_depth,
phy.loopback_latency)) phy.loopback_latency))
self.submodules.bank_i = InsertReset(_RTIOBankI(phy.rbus, self.submodules.bank_i = InsertReset(_RTIOBankI(phy.rbus,
counter_width, fine_ts_width, ofifo_depth)) counter_width, fine_ts_width, ofifo_depth))
# CSRs # CSRs
self._r_reset = CSRStorage(reset=1) self._r_reset = CSRStorage(reset=1)
self._r_chan_sel = CSRStorage(flen(self.bank_o.sel)) self._r_chan_sel = CSRStorage(flen(self.bank_o.sel))
self._r_oe = CSR() self._r_oe = CSR()
self._r_o_timestamp = CSRStorage(counter_width+fine_ts_width) self._r_o_timestamp = CSRStorage(counter_width+fine_ts_width)
self._r_o_value = CSRStorage(2) self._r_o_value = CSRStorage(2)
self._r_o_writable = CSRStatus() self._r_o_writable = CSRStatus()
self._r_o_we = CSR() self._r_o_we = CSR()
self._r_o_underflow = CSRStatus() self._r_o_underflow = CSRStatus()
self._r_o_level = CSRStatus(bits_for(ofifo_depth)) self._r_o_level = CSRStatus(bits_for(ofifo_depth))
self._r_i_timestamp = CSRStatus(counter_width+fine_ts_width) self._r_i_timestamp = CSRStatus(counter_width+fine_ts_width)
self._r_i_value = CSRStatus() self._r_i_value = CSRStatus()
self._r_i_readable = CSRStatus() self._r_i_readable = CSRStatus()
self._r_i_re = CSR() self._r_i_re = CSR()
self._r_i_overflow = CSRStatus() self._r_i_overflow = CSRStatus()
# OE # OE
oes = [] oes = []
for n, chif in enumerate(phy.rbus): for n, chif in enumerate(phy.rbus):
if hasattr(chif, "oe"): if hasattr(chif, "oe"):
self.sync += \ self.sync += \
If(self._r_oe.re & (self._r_chan_sel.storage == n), If(self._r_oe.re & (self._r_chan_sel.storage == n),
chif.oe.eq(self._r_oe.r) chif.oe.eq(self._r_oe.r)
) )
oes.append(chif.oe) oes.append(chif.oe)
else: else:
oes.append(1) oes.append(1)
self.comb += self._r_oe.w.eq(Array(oes)[self._r_chan_sel.storage]) self.comb += self._r_oe.w.eq(Array(oes)[self._r_chan_sel.storage])
# Output/Gate # Output/Gate
self.comb += [ self.comb += [
self.bank_o.reset.eq(self._r_reset.storage), self.bank_o.reset.eq(self._r_reset.storage),
self.bank_o.sel.eq(self._r_chan_sel.storage), self.bank_o.sel.eq(self._r_chan_sel.storage),
self.bank_o.timestamp.eq(self._r_o_timestamp.storage), self.bank_o.timestamp.eq(self._r_o_timestamp.storage),
self.bank_o.value.eq(self._r_o_value.storage), self.bank_o.value.eq(self._r_o_value.storage),
self._r_o_writable.status.eq(self.bank_o.writable), self._r_o_writable.status.eq(self.bank_o.writable),
self.bank_o.we.eq(self._r_o_we.re), self.bank_o.we.eq(self._r_o_we.re),
self._r_o_underflow.status.eq(self.bank_o.underflow), self._r_o_underflow.status.eq(self.bank_o.underflow),
self._r_o_level.status.eq(self.bank_o.level) self._r_o_level.status.eq(self.bank_o.level)
] ]
# Input # Input
self.comb += [ self.comb += [
self.bank_i.reset.eq(self._r_reset.storage), self.bank_i.reset.eq(self._r_reset.storage),
self.bank_i.sel.eq(self._r_chan_sel.storage), self.bank_i.sel.eq(self._r_chan_sel.storage),
self._r_i_timestamp.status.eq(self.bank_i.timestamp), self._r_i_timestamp.status.eq(self.bank_i.timestamp),
self._r_i_value.status.eq(self.bank_i.value), self._r_i_value.status.eq(self.bank_i.value),
self._r_i_readable.status.eq(self.bank_i.readable), self._r_i_readable.status.eq(self.bank_i.readable),
self.bank_i.re.eq(self._r_i_re.re), self.bank_i.re.eq(self._r_i_re.re),
self._r_i_overflow.status.eq(self.bank_i.overflow) self._r_i_overflow.status.eq(self.bank_i.overflow)
] ]

View File

@ -3,27 +3,28 @@ from migen.genlib.cdc import MultiReg
from artiqlib.rtio.rbus import create_rbus from artiqlib.rtio.rbus import create_rbus
class SimplePHY(Module): class SimplePHY(Module):
def __init__(self, pads, output_only_pads=set()): def __init__(self, pads, output_only_pads=set()):
self.rbus = create_rbus(0, pads, output_only_pads) self.rbus = create_rbus(0, pads, output_only_pads)
self.loopback_latency = 3 self.loopback_latency = 3
### # # #
for pad, chif in zip(pads, self.rbus): for pad, chif in zip(pads, self.rbus):
o_pad = Signal() o_pad = Signal()
self.sync += If(chif.o_stb, o_pad.eq(chif.o_value)) self.sync += If(chif.o_stb, o_pad.eq(chif.o_value))
if pad in output_only_pads: if pad in output_only_pads:
self.comb += pad.eq(o_pad) self.comb += pad.eq(o_pad)
else: else:
ts = TSTriple() ts = TSTriple()
i_pad = Signal() i_pad = Signal()
self.sync += ts.oe.eq(chif.oe) self.sync += ts.oe.eq(chif.oe)
self.comb += ts.o.eq(o_pad) self.comb += ts.o.eq(o_pad)
self.specials += MultiReg(ts.i, i_pad), \ self.specials += MultiReg(ts.i, i_pad), \
ts.get_tristate(pad) ts.get_tristate(pad)
i_pad_d = Signal() i_pad_d = Signal()
self.sync += i_pad_d.eq(i_pad) self.sync += i_pad_d.eq(i_pad)
self.comb += chif.i_stb.eq(i_pad ^ i_pad_d), \ self.comb += chif.i_stb.eq(i_pad ^ i_pad_d), \
chif.i_value.eq(i_pad) chif.i_value.eq(i_pad)

View File

@ -2,27 +2,27 @@ from migen.fhdl.std import *
from migen.genlib.record import Record from migen.genlib.record import Record
def create_rbus(fine_ts_bits, pads, output_only_pads): def create_rbus(fine_ts_bits, pads, output_only_pads):
rbus = [] rbus = []
for pad in pads: for pad in pads:
layout = [ layout = [
("o_stb", 1), ("o_stb", 1),
("o_value", 2) ("o_value", 2)
] ]
if fine_ts_bits: if fine_ts_bits:
layout.append(("o_fine_ts", fine_ts_bits)) layout.append(("o_fine_ts", fine_ts_bits))
if pad not in output_only_pads: if pad not in output_only_pads:
layout += [ layout += [
("oe", 1), ("oe", 1),
("i_stb", 1), ("i_stb", 1),
("i_value", 1) ("i_value", 1)
] ]
if fine_ts_bits: if fine_ts_bits:
layout.append(("i_fine_ts", fine_ts_bits)) layout.append(("i_fine_ts", fine_ts_bits))
rbus.append(Record(layout)) rbus.append(Record(layout))
return rbus return rbus
def get_fine_ts_width(rbus): def get_fine_ts_width(rbus):
if hasattr(rbus[0], "o_fine_ts"): if hasattr(rbus[0], "o_fine_ts"):
return flen(rbus[0].o_fine_ts) return flen(rbus[0].o_fine_ts)
else: else:
return 0 return 0

View File

@ -6,123 +6,123 @@
#include "corecom.h" #include "corecom.h"
enum { enum {
MSGTYPE_REQUEST_IDENT = 0x01, MSGTYPE_REQUEST_IDENT = 0x01,
MSGTYPE_LOAD_KERNEL = 0x02, MSGTYPE_LOAD_KERNEL = 0x02,
MSGTYPE_KERNEL_FINISHED = 0x03, MSGTYPE_KERNEL_FINISHED = 0x03,
MSGTYPE_RPC_REQUEST = 0x04, MSGTYPE_RPC_REQUEST = 0x04,
}; };
static int receive_int(void) static int receive_int(void)
{ {
unsigned int r; unsigned int r;
int i; int i;
r = 0; r = 0;
for(i=0;i<4;i++) { for(i=0;i<4;i++) {
r <<= 8; r <<= 8;
r |= (unsigned char)uart_read(); r |= (unsigned char)uart_read();
} }
return r; return r;
} }
static char receive_char(void) static char receive_char(void)
{ {
return uart_read(); return uart_read();
} }
static void send_int(int x) static void send_int(int x)
{ {
int i; int i;
for(i=0;i<4;i++) { for(i=0;i<4;i++) {
uart_write((x & 0xff000000) >> 24); uart_write((x & 0xff000000) >> 24);
x <<= 8; x <<= 8;
} }
} }
static void send_sint(short int i) static void send_sint(short int i)
{ {
uart_write((i >> 8) & 0xff); uart_write((i >> 8) & 0xff);
uart_write(i & 0xff); uart_write(i & 0xff);
} }
static void send_char(char c) static void send_char(char c)
{ {
uart_write(c); uart_write(c);
} }
static void receive_sync(void) static void receive_sync(void)
{ {
char c; char c;
int recognized; int recognized;
recognized = 0; recognized = 0;
while(recognized < 4) { while(recognized < 4) {
c = uart_read(); c = uart_read();
if(c == 0x5a) if(c == 0x5a)
recognized++; recognized++;
else else
recognized = 0; recognized = 0;
} }
} }
static void send_sync(void) static void send_sync(void)
{ {
send_int(0x5a5a5a5a); send_int(0x5a5a5a5a);
} }
int ident_and_download_kernel(void *buffer, int maxlength) int ident_and_download_kernel(void *buffer, int maxlength)
{ {
int length; int length;
unsigned int crc; unsigned int crc;
int i; int i;
char msgtype; char msgtype;
unsigned char *_buffer = buffer; unsigned char *_buffer = buffer;
while(1) { while(1) {
receive_sync(); receive_sync();
msgtype = receive_char(); msgtype = receive_char();
if(msgtype == MSGTYPE_REQUEST_IDENT) { if(msgtype == MSGTYPE_REQUEST_IDENT) {
send_int(0x41524f52); /* "AROR" - ARTIQ runtime on OpenRISC */ send_int(0x41524f52); /* "AROR" - ARTIQ runtime on OpenRISC */
send_int(1000000000000LL/identifier_frequency_read()); /* RTIO clock period in picoseconds */ send_int(1000000000000LL/identifier_frequency_read()); /* RTIO clock period in picoseconds */
} else if(msgtype == MSGTYPE_LOAD_KERNEL) { } else if(msgtype == MSGTYPE_LOAD_KERNEL) {
length = receive_int(); length = receive_int();
if(length > maxlength) { if(length > maxlength) {
send_char(0x4c); /* Incorrect length */ send_char(0x4c); /* Incorrect length */
return -1; return -1;
} }
crc = receive_int(); crc = receive_int();
for(i=0;i<length;i++) for(i=0;i<length;i++)
_buffer[i] = receive_char(); _buffer[i] = receive_char();
if(crc32(buffer, length) != crc) { if(crc32(buffer, length) != crc) {
send_char(0x43); /* CRC failed */ send_char(0x43); /* CRC failed */
return -1; return -1;
} }
send_char(0x4f); /* kernel reception OK */ send_char(0x4f); /* kernel reception OK */
return length; return length;
} else } else
return -1; return -1;
} }
} }
int rpc(int rpc_num, int n_args, ...) int rpc(int rpc_num, int n_args, ...)
{ {
send_sync(); send_sync();
send_char(MSGTYPE_RPC_REQUEST); send_char(MSGTYPE_RPC_REQUEST);
send_sint(rpc_num); send_sint(rpc_num);
send_char(n_args); send_char(n_args);
va_list args; va_list args;
va_start(args, n_args); va_start(args, n_args);
while(n_args--) while(n_args--)
send_int(va_arg(args, int)); send_int(va_arg(args, int));
va_end(args); va_end(args);
return receive_int(); return receive_int();
} }
void kernel_finished(void) void kernel_finished(void)
{ {
send_sync(); send_sync();
send_char(MSGTYPE_KERNEL_FINISHED); send_char(MSGTYPE_KERNEL_FINISHED);
} }

View File

@ -10,33 +10,33 @@
#define DDS_GPIO 0x41 #define DDS_GPIO 0x41
#define DDS_READ(addr) \ #define DDS_READ(addr) \
MMPTR(0xb0000000 + (addr)*4) MMPTR(0xb0000000 + (addr)*4)
#define DDS_WRITE(addr, data) \ #define DDS_WRITE(addr, data) \
MMPTR(0xb0000000 + (addr)*4) = data MMPTR(0xb0000000 + (addr)*4) = data
void dds_init(void) void dds_init(void)
{ {
int i; int i;
DDS_WRITE(DDS_GPIO, 1 << 7); DDS_WRITE(DDS_GPIO, 1 << 7);
for(i=0;i<8;i++) { for(i=0;i<8;i++) {
DDS_WRITE(DDS_GPIO, i); DDS_WRITE(DDS_GPIO, i);
DDS_WRITE(0x00, 0x78); DDS_WRITE(0x00, 0x78);
DDS_WRITE(0x01, 0x00); DDS_WRITE(0x01, 0x00);
DDS_WRITE(0x02, 0x00); DDS_WRITE(0x02, 0x00);
DDS_WRITE(0x03, 0x00); DDS_WRITE(0x03, 0x00);
DDS_WRITE(DDS_FUD, 0); DDS_WRITE(DDS_FUD, 0);
} }
} }
void dds_program(int channel, int ftw) void dds_program(int channel, int ftw)
{ {
DDS_WRITE(DDS_GPIO, channel); DDS_WRITE(DDS_GPIO, channel);
DDS_WRITE(DDS_FTW0, ftw & 0xff); DDS_WRITE(DDS_FTW0, ftw & 0xff);
DDS_WRITE(DDS_FTW1, (ftw >> 8) & 0xff); DDS_WRITE(DDS_FTW1, (ftw >> 8) & 0xff);
DDS_WRITE(DDS_FTW2, (ftw >> 16) & 0xff); DDS_WRITE(DDS_FTW2, (ftw >> 16) & 0xff);
DDS_WRITE(DDS_FTW3, (ftw >> 24) & 0xff); DDS_WRITE(DDS_FTW3, (ftw >> 24) & 0xff);
DDS_WRITE(DDS_FUD, 0); DDS_WRITE(DDS_FUD, 0);
} }

View File

@ -6,27 +6,27 @@
#define EI_NIDENT 16 #define EI_NIDENT 16
struct elf32_ehdr { struct elf32_ehdr {
unsigned char ident[EI_NIDENT]; /* ident bytes */ unsigned char ident[EI_NIDENT]; /* ident bytes */
unsigned short type; /* file type */ unsigned short type; /* file type */
unsigned short machine; /* target machine */ unsigned short machine; /* target machine */
unsigned int version; /* file version */ unsigned int version; /* file version */
unsigned int entry; /* start address */ unsigned int entry; /* start address */
unsigned int phoff; /* phdr file offset */ unsigned int phoff; /* phdr file offset */
unsigned int shoff; /* shdr file offset */ unsigned int shoff; /* shdr file offset */
unsigned int flags; /* file flags */ unsigned int flags; /* file flags */
unsigned short ehsize; /* sizeof ehdr */ unsigned short ehsize; /* sizeof ehdr */
unsigned short phentsize; /* sizeof phdr */ unsigned short phentsize; /* sizeof phdr */
unsigned short phnum; /* number phdrs */ unsigned short phnum; /* number phdrs */
unsigned short shentsize; /* sizeof shdr */ unsigned short shentsize; /* sizeof shdr */
unsigned short shnum; /* number shdrs */ unsigned short shnum; /* number shdrs */
unsigned short shstrndx; /* shdr string index */ unsigned short shstrndx; /* shdr string index */
} __attribute__((packed)); } __attribute__((packed));
static const unsigned char elf_magic_header[] = { static const unsigned char elf_magic_header[] = {
0x7f, 0x45, 0x4c, 0x46, /* 0x7f, 'E', 'L', 'F' */ 0x7f, 0x45, 0x4c, 0x46, /* 0x7f, 'E', 'L', 'F' */
0x01, /* Only 32-bit objects. */ 0x01, /* Only 32-bit objects. */
0x02, /* Only big-endian. */ 0x02, /* Only big-endian. */
0x01, /* Only ELF version 1. */ 0x01, /* Only ELF version 1. */
}; };
#define ET_NONE 0 /* Unknown type. */ #define ET_NONE 0 /* Unknown type. */
@ -38,26 +38,26 @@ static const unsigned char elf_magic_header[] = {
#define EM_OR1K 0x005c #define EM_OR1K 0x005c
struct elf32_shdr { struct elf32_shdr {
unsigned int name; /* section name */ unsigned int name; /* section name */
unsigned int type; /* SHT_... */ unsigned int type; /* SHT_... */
unsigned int flags; /* SHF_... */ unsigned int flags; /* SHF_... */
unsigned int addr; /* virtual address */ unsigned int addr; /* virtual address */
unsigned int offset; /* file offset */ unsigned int offset; /* file offset */
unsigned int size; /* section size */ unsigned int size; /* section size */
unsigned int link; /* misc info */ unsigned int link; /* misc info */
unsigned int info; /* misc info */ unsigned int info; /* misc info */
unsigned int addralign; /* memory alignment */ unsigned int addralign; /* memory alignment */
unsigned int entsize; /* entry size if table */ unsigned int entsize; /* entry size if table */
} __attribute__((packed)); } __attribute__((packed));
struct elf32_name { struct elf32_name {
char name[12]; char name[12];
} __attribute__((packed)); } __attribute__((packed));
struct elf32_rela { struct elf32_rela {
unsigned int offset; /* Location to be relocated. */ unsigned int offset; /* Location to be relocated. */
unsigned int info; /* Relocation type and symbol index. */ unsigned int info; /* Relocation type and symbol index. */
int addend; /* Addend. */ int addend; /* Addend. */
} __attribute__((packed)); } __attribute__((packed));
#define ELF32_R_SYM(info) ((info) >> 8) #define ELF32_R_SYM(info) ((info) >> 8)
@ -66,151 +66,151 @@ struct elf32_rela {
#define R_OR1K_INSN_REL_26 6 #define R_OR1K_INSN_REL_26 6
struct elf32_sym { struct elf32_sym {
unsigned int name; /* String table index of name. */ unsigned int name; /* String table index of name. */
unsigned int value; /* Symbol value. */ unsigned int value; /* Symbol value. */
unsigned int size; /* Size of associated object. */ unsigned int size; /* Size of associated object. */
unsigned char info; /* Type and binding information. */ unsigned char info; /* Type and binding information. */
unsigned char other; /* Reserved (not used). */ unsigned char other; /* Reserved (not used). */
unsigned short shndx; /* Section index of symbol. */ unsigned short shndx; /* Section index of symbol. */
} __attribute__((packed)); } __attribute__((packed));
#define SANITIZE_OFFSET_SIZE(offset, size) \ #define SANITIZE_OFFSET_SIZE(offset, size) \
if(offset > 0x10000000) { \ if(offset > 0x10000000) { \
printf("Incorrect offset in ELF data"); \ printf("Incorrect offset in ELF data"); \
return 0; \ return 0; \
} \ } \
if((offset + size) > elf_length) { \ if((offset + size) > elf_length) { \
printf("Attempted to access past the end of ELF data"); \ printf("Attempted to access past the end of ELF data"); \
return 0; \ return 0; \
} }
#define GET_POINTER_SAFE(target, target_type, offset) \ #define GET_POINTER_SAFE(target, target_type, offset) \
SANITIZE_OFFSET_SIZE(offset, sizeof(target_type)); \ SANITIZE_OFFSET_SIZE(offset, sizeof(target_type)); \
target = (target_type *)((char *)elf_data + offset) target = (target_type *)((char *)elf_data + offset)
void *find_symbol(const struct symbol *symbols, const char *name) void *find_symbol(const struct symbol *symbols, const char *name)
{ {
int i; int i;
i = 0; i = 0;
while((symbols[i].name != NULL) && (strcmp(symbols[i].name, name) != 0)) while((symbols[i].name != NULL) && (strcmp(symbols[i].name, name) != 0))
i++; i++;
return symbols[i].target; return symbols[i].target;
} }
static int fixup(void *dest, int dest_length, struct elf32_rela *rela, void *target) static int fixup(void *dest, int dest_length, struct elf32_rela *rela, void *target)
{ {
int type, offset; int type, offset;
unsigned int *_dest = dest; unsigned int *_dest = dest;
unsigned int *_target = target; unsigned int *_target = target;
type = ELF32_R_TYPE(rela->info); type = ELF32_R_TYPE(rela->info);
offset = rela->offset/4; offset = rela->offset/4;
if(type == R_OR1K_INSN_REL_26) { if(type == R_OR1K_INSN_REL_26) {
int val; int val;
val = _target - (_dest + offset); val = _target - (_dest + offset);
_dest[offset] = (_dest[offset] & 0xfc000000) | (val & 0x03ffffff); _dest[offset] = (_dest[offset] & 0xfc000000) | (val & 0x03ffffff);
} else } else
printf("Unsupported relocation type: %d\n", type); printf("Unsupported relocation type: %d\n", type);
return 1; return 1;
} }
int load_elf(symbol_resolver resolver, void *elf_data, int elf_length, void *dest, int dest_length) int load_elf(symbol_resolver resolver, void *elf_data, int elf_length, void *dest, int dest_length)
{ {
struct elf32_ehdr *ehdr; struct elf32_ehdr *ehdr;
struct elf32_shdr *strtable; struct elf32_shdr *strtable;
unsigned int shdrptr; unsigned int shdrptr;
int i; int i;
unsigned int textoff, textsize; unsigned int textoff, textsize;
unsigned int textrelaoff, textrelasize; unsigned int textrelaoff, textrelasize;
unsigned int symtaboff, symtabsize; unsigned int symtaboff, symtabsize;
unsigned int strtaboff, strtabsize; unsigned int strtaboff, strtabsize;
/* validate ELF */ /* validate ELF */
GET_POINTER_SAFE(ehdr, struct elf32_ehdr, 0); GET_POINTER_SAFE(ehdr, struct elf32_ehdr, 0);
if(memcmp(ehdr->ident, elf_magic_header, sizeof(elf_magic_header)) != 0) { if(memcmp(ehdr->ident, elf_magic_header, sizeof(elf_magic_header)) != 0) {
printf("Incorrect ELF header\n"); printf("Incorrect ELF header\n");
return 0; return 0;
} }
if(ehdr->type != ET_REL) { if(ehdr->type != ET_REL) {
printf("ELF is not relocatable\n"); printf("ELF is not relocatable\n");
return 0; return 0;
} }
if(ehdr->machine != EM_OR1K) { if(ehdr->machine != EM_OR1K) {
printf("ELF is for a different machine\n"); printf("ELF is for a different machine\n");
return 0; return 0;
} }
/* extract section info */ /* extract section info */
GET_POINTER_SAFE(strtable, struct elf32_shdr, ehdr->shoff + ehdr->shentsize*ehdr->shstrndx); GET_POINTER_SAFE(strtable, struct elf32_shdr, ehdr->shoff + ehdr->shentsize*ehdr->shstrndx);
textoff = textsize = 0; textoff = textsize = 0;
textrelaoff = textrelasize = 0; textrelaoff = textrelasize = 0;
symtaboff = symtabsize = 0; symtaboff = symtabsize = 0;
strtaboff = strtabsize = 0; strtaboff = strtabsize = 0;
shdrptr = ehdr->shoff; shdrptr = ehdr->shoff;
for(i=0;i<ehdr->shnum;i++) { for(i=0;i<ehdr->shnum;i++) {
struct elf32_shdr *shdr; struct elf32_shdr *shdr;
struct elf32_name *name; struct elf32_name *name;
GET_POINTER_SAFE(shdr, struct elf32_shdr, shdrptr); GET_POINTER_SAFE(shdr, struct elf32_shdr, shdrptr);
GET_POINTER_SAFE(name, struct elf32_name, strtable->offset + shdr->name); GET_POINTER_SAFE(name, struct elf32_name, strtable->offset + shdr->name);
if(strncmp(name->name, ".text", 5) == 0) { if(strncmp(name->name, ".text", 5) == 0) {
textoff = shdr->offset; textoff = shdr->offset;
textsize = shdr->size; textsize = shdr->size;
} else if(strncmp(name->name, ".rela.text", 10) == 0) { } else if(strncmp(name->name, ".rela.text", 10) == 0) {
textrelaoff = shdr->offset; textrelaoff = shdr->offset;
textrelasize = shdr->size; textrelasize = shdr->size;
} else if(strncmp(name->name, ".symtab", 7) == 0) { } else if(strncmp(name->name, ".symtab", 7) == 0) {
symtaboff = shdr->offset; symtaboff = shdr->offset;
symtabsize = shdr->size; symtabsize = shdr->size;
} else if(strncmp(name->name, ".strtab", 7) == 0) { } else if(strncmp(name->name, ".strtab", 7) == 0) {
strtaboff = shdr->offset; strtaboff = shdr->offset;
strtabsize = shdr->size; strtabsize = shdr->size;
} }
shdrptr += ehdr->shentsize; shdrptr += ehdr->shentsize;
} }
SANITIZE_OFFSET_SIZE(textoff, textsize); SANITIZE_OFFSET_SIZE(textoff, textsize);
SANITIZE_OFFSET_SIZE(textrelaoff, textrelasize); SANITIZE_OFFSET_SIZE(textrelaoff, textrelasize);
SANITIZE_OFFSET_SIZE(symtaboff, symtabsize); SANITIZE_OFFSET_SIZE(symtaboff, symtabsize);
SANITIZE_OFFSET_SIZE(strtaboff, strtabsize); SANITIZE_OFFSET_SIZE(strtaboff, strtabsize);
/* load .text section */ /* load .text section */
if(textsize > dest_length) { if(textsize > dest_length) {
printf(".text section is too large\n"); printf(".text section is too large\n");
return 0; return 0;
} }
memcpy(dest, (char *)elf_data + textoff, textsize); memcpy(dest, (char *)elf_data + textoff, textsize);
/* process .text relocations */ /* process .text relocations */
for(i=0;i<textrelasize;i+=sizeof(struct elf32_rela)) { for(i=0;i<textrelasize;i+=sizeof(struct elf32_rela)) {
struct elf32_rela *rela; struct elf32_rela *rela;
struct elf32_sym *sym; struct elf32_sym *sym;
char *name; char *name;
GET_POINTER_SAFE(rela, struct elf32_rela, textrelaoff + i); GET_POINTER_SAFE(rela, struct elf32_rela, textrelaoff + i);
GET_POINTER_SAFE(sym, struct elf32_sym, symtaboff + sizeof(struct elf32_sym)*ELF32_R_SYM(rela->info)); GET_POINTER_SAFE(sym, struct elf32_sym, symtaboff + sizeof(struct elf32_sym)*ELF32_R_SYM(rela->info));
if(sym->name != 0) { if(sym->name != 0) {
void *target; void *target;
name = (char *)elf_data + strtaboff + sym->name; name = (char *)elf_data + strtaboff + sym->name;
target = resolver(name); target = resolver(name);
if(target == NULL) { if(target == NULL) {
printf("Undefined symbol: %s\n", name); printf("Undefined symbol: %s\n", name);
return 0; return 0;
} }
if(!fixup(dest, dest_length, rela, target)) if(!fixup(dest, dest_length, rela, target))
return 0; return 0;
} else { } else {
printf("Unsupported relocation\n"); printf("Unsupported relocation\n");
return 0; return 0;
} }
} }
return 1; return 1;
} }

View File

@ -2,8 +2,8 @@
#define __ELF_LOADER_H #define __ELF_LOADER_H
struct symbol { struct symbol {
char *name; char *name;
void *target; void *target;
}; };
void *find_symbol(const struct symbol *symbols, const char *name); void *find_symbol(const struct symbol *symbols, const char *name);

View File

@ -4,11 +4,11 @@
void gpio_set(int channel, int value) void gpio_set(int channel, int value)
{ {
static int csr_value; static int csr_value;
if(value) if(value)
csr_value |= 1 << channel; csr_value |= 1 << channel;
else else
csr_value &= ~(1 << channel); csr_value &= ~(1 << channel);
leds_out_write(csr_value); leds_out_write(csr_value);
} }

View File

@ -5,10 +5,10 @@
void isr(void); void isr(void);
void isr(void) void isr(void)
{ {
unsigned int irqs; unsigned int irqs;
irqs = irq_pending() & irq_getmask(); irqs = irq_pending() & irq_getmask();
if(irqs & (1 << UART_INTERRUPT)) if(irqs & (1 << UART_INTERRUPT))
uart_isr(); uart_isr();
} }

View File

@ -13,29 +13,29 @@ typedef void (*kernel_function)(void);
int main(void) int main(void)
{ {
unsigned char kbuf[256*1024]; unsigned char kbuf[256*1024];
unsigned char kcode[256*1024]; unsigned char kcode[256*1024];
kernel_function k = (kernel_function)kcode; kernel_function k = (kernel_function)kcode;
int length; int length;
irq_setmask(0); irq_setmask(0);
irq_setie(1); irq_setie(1);
uart_init(); uart_init();
puts("ARTIQ runtime built "__DATE__" "__TIME__"\n"); puts("ARTIQ runtime built "__DATE__" "__TIME__"\n");
while(1) { while(1) {
length = ident_and_download_kernel(kbuf, sizeof(kbuf)); length = ident_and_download_kernel(kbuf, sizeof(kbuf));
if(length > 0) { if(length > 0) {
if(load_elf(resolve_symbol, kbuf, length, kcode, sizeof(kcode))) { if(load_elf(resolve_symbol, kbuf, length, kcode, sizeof(kcode))) {
rtio_init(); rtio_init();
dds_init(); dds_init();
flush_cpu_icache(); flush_cpu_icache();
k(); k();
kernel_finished(); kernel_finished();
} }
} }
} }
return 0; return 0;
} }

View File

@ -4,21 +4,21 @@
void rtio_init(void) void rtio_init(void)
{ {
rtio_reset_write(1); rtio_reset_write(1);
} }
void rtio_set(long long int timestamp, int channel, int value) void rtio_set(long long int timestamp, int channel, int value)
{ {
rtio_reset_write(0); rtio_reset_write(0);
rtio_chan_sel_write(channel); rtio_chan_sel_write(channel);
rtio_o_timestamp_write(timestamp); rtio_o_timestamp_write(timestamp);
rtio_o_value_write(value); rtio_o_value_write(value);
while(!rtio_o_writable_read()); while(!rtio_o_writable_read());
rtio_o_we_write(1); rtio_o_we_write(1);
} }
void rtio_sync(int channel) void rtio_sync(int channel)
{ {
rtio_chan_sel_write(channel); rtio_chan_sel_write(channel);
while(rtio_o_level_read() != 0); while(rtio_o_level_read() != 0);
} }

View File

@ -8,34 +8,34 @@
#include "symbols.h" #include "symbols.h"
static const struct symbol syscalls[] = { static const struct symbol syscalls[] = {
{"rpc", rpc}, {"rpc", rpc},
{"gpio_set", gpio_set}, {"gpio_set", gpio_set},
{"rtio_set", rtio_set}, {"rtio_set", rtio_set},
{"rtio_sync", rtio_sync}, {"rtio_sync", rtio_sync},
{"dds_program", dds_program}, {"dds_program", dds_program},
{NULL, NULL} {NULL, NULL}
}; };
static long long int gcd64(long long int a, long long int b) static long long int gcd64(long long int a, long long int b)
{ {
long long int c; long long int c;
while(a) { while(a) {
c = a; c = a;
a = b % a; a = b % a;
b = c; b = c;
} }
return b; return b;
} }
static const struct symbol arithmetic[] = { static const struct symbol arithmetic[] = {
{"__gcd64", gcd64}, {"__gcd64", gcd64},
{NULL, NULL} {NULL, NULL}
}; };
void *resolve_symbol(const char *name) void *resolve_symbol(const char *name)
{ {
if(strncmp(name, "__syscall_", 10) == 0) if(strncmp(name, "__syscall_", 10) == 0)
return find_symbol(syscalls, name + 10); return find_symbol(syscalls, name + 10);
return find_symbol(arithmetic, name); return find_symbol(arithmetic, name);
} }

View File

@ -7,44 +7,44 @@ from targets.ppro import BaseSoC
from artiqlib import rtio, ad9858 from artiqlib import rtio, ad9858
_tester_io = [ _tester_io = [
("user_led", 1, Pins("B:7"), IOStandard("LVTTL")), ("user_led", 1, Pins("B:7"), IOStandard("LVTTL")),
("ttl", 0, Pins("C:13"), IOStandard("LVTTL")), ("ttl", 0, Pins("C:13"), IOStandard("LVTTL")),
("ttl", 1, Pins("C:11"), IOStandard("LVTTL")), ("ttl", 1, Pins("C:11"), IOStandard("LVTTL")),
("ttl", 2, Pins("C:10"), IOStandard("LVTTL")), ("ttl", 2, Pins("C:10"), IOStandard("LVTTL")),
("ttl", 3, Pins("C:9"), IOStandard("LVTTL")), ("ttl", 3, Pins("C:9"), IOStandard("LVTTL")),
("ttl_tx_en", 0, Pins("A:9"), IOStandard("LVTTL")), ("ttl_tx_en", 0, Pins("A:9"), IOStandard("LVTTL")),
("dds", 0, ("dds", 0,
Subsignal("a", Pins("A:5 B:10 A:6 B:9 A:7 B:8")), Subsignal("a", Pins("A:5 B:10 A:6 B:9 A:7 B:8")),
Subsignal("d", Pins("A:12 B:3 A:13 B:2 A:14 B:1 A:15 B:0")), Subsignal("d", Pins("A:12 B:3 A:13 B:2 A:14 B:1 A:15 B:0")),
Subsignal("sel", Pins("A:2 B:14 A:1 B:15 A:0")), Subsignal("sel", Pins("A:2 B:14 A:1 B:15 A:0")),
Subsignal("p", Pins("A:8 B:12")), Subsignal("p", Pins("A:8 B:12")),
Subsignal("fud_n", Pins("B:11")), Subsignal("fud_n", Pins("B:11")),
Subsignal("wr_n", Pins("A:4")), Subsignal("wr_n", Pins("A:4")),
Subsignal("rd_n", Pins("B:13")), Subsignal("rd_n", Pins("B:13")),
Subsignal("rst_n", Pins("A:3")), Subsignal("rst_n", Pins("A:3")),
IOStandard("LVTTL")), IOStandard("LVTTL")),
] ]
class ARTIQMiniSoC(BaseSoC): class ARTIQMiniSoC(BaseSoC):
csr_map = { csr_map = {
"rtio": 10 "rtio": 10
} }
csr_map.update(BaseSoC.csr_map) csr_map.update(BaseSoC.csr_map)
def __init__(self, platform, cpu_type="or1k", **kwargs): def __init__(self, platform, cpu_type="or1k", **kwargs):
BaseSoC.__init__(self, platform, cpu_type=cpu_type, **kwargs) BaseSoC.__init__(self, platform, cpu_type=cpu_type, **kwargs)
platform.add_extension(_tester_io) platform.add_extension(_tester_io)
self.submodules.leds = gpio.GPIOOut(Cat(platform.request("user_led", 0), self.submodules.leds = gpio.GPIOOut(Cat(platform.request("user_led", 0),
platform.request("user_led", 1))) platform.request("user_led", 1)))
self.comb += platform.request("ttl_tx_en").eq(1) self.comb += platform.request("ttl_tx_en").eq(1)
rtio_pads = [platform.request("ttl", i) for i in range(4)] rtio_pads = [platform.request("ttl", i) for i in range(4)]
self.submodules.rtiophy = rtio.phy.SimplePHY(rtio_pads, self.submodules.rtiophy = rtio.phy.SimplePHY(rtio_pads,
{rtio_pads[1], rtio_pads[2], rtio_pads[3]}) {rtio_pads[1], rtio_pads[2], rtio_pads[3]})
self.submodules.rtio = rtio.RTIO(self.rtiophy) self.submodules.rtio = rtio.RTIO(self.rtiophy)
self.submodules.dds = ad9858.AD9858(platform.request("dds")) self.submodules.dds = ad9858.AD9858(platform.request("dds"))
self.add_wb_slave(lambda a: a[26:29] == 3, self.dds.bus) self.add_wb_slave(lambda a: a[26:29] == 3, self.dds.bus)
default_subtarget = ARTIQMiniSoC default_subtarget = ARTIQMiniSoC