forked from M-Labs/artiq
1
0
Fork 0
This commit is contained in:
Sebastien Bourdeauducq 2014-09-05 12:03:22 +08:00
parent 7e9df82e37
commit 4915b4b5aa
43 changed files with 3179 additions and 2915 deletions

View File

@ -1,8 +1,10 @@
import ast, operator import ast
import operator
from artiq.compiler.tools import * from artiq.compiler.tools import *
from artiq.language.core import int64, round64 from artiq.language.core import int64, round64
_ast_unops = { _ast_unops = {
ast.Invert: operator.inv, ast.Invert: operator.inv,
ast.Not: operator.not_, ast.Not: operator.not_,
@ -10,6 +12,7 @@ _ast_unops = {
ast.USub: operator.neg ast.USub: operator.neg
} }
_ast_binops = { _ast_binops = {
ast.Add: operator.add, ast.Add: operator.add,
ast.Sub: operator.sub, ast.Sub: operator.sub,
@ -25,6 +28,7 @@ _ast_binops = {
ast.BitAnd: operator.and_ ast.BitAnd: operator.and_
} }
class _ConstantFolder(ast.NodeTransformer): class _ConstantFolder(ast.NodeTransformer):
def visit_UnaryOp(self, node): def visit_UnaryOp(self, node):
self.generic_visit(node) self.generic_visit(node)
@ -77,5 +81,6 @@ class _ConstantFolder(ast.NodeTransformer):
else: else:
return node return node
def fold_constants(node): def fold_constants(node):
_ConstantFolder().visit(node) _ConstantFolder().visit(node)

View File

@ -1,21 +1,27 @@
from collections import namedtuple, defaultdict from collections import namedtuple, defaultdict
from fractions import Fraction from fractions import Fraction
import inspect, textwrap, ast import inspect
import textwrap
import ast
from artiq.compiler.tools import eval_ast, value_to_ast from artiq.compiler.tools import eval_ast, value_to_ast
from artiq.language import core as core_language from artiq.language import core as core_language
from artiq.language import units from artiq.language import units
_UserVariable = namedtuple("_UserVariable", "name") _UserVariable = namedtuple("_UserVariable", "name")
def _is_in_attr_list(obj, attr, al): def _is_in_attr_list(obj, attr, al):
if not hasattr(obj, al): if not hasattr(obj, al):
return False return False
return attr in getattr(obj, al).split() return attr in getattr(obj, al).split()
class _ReferenceManager: class _ReferenceManager:
def __init__(self): def __init__(self):
# (id(obj), funcname, local) -> _UserVariable(name) / ast / constant_object # (id(obj), funcname, local)
# -> _UserVariable(name) / ast / constant_object
# local is None for kernel attributes # local is None for kernel attributes
self.to_inlined = dict() self.to_inlined = dict()
# inlined_name -> use_count # inlined_name -> use_count
@ -26,9 +32,9 @@ class _ReferenceManager:
# reserved names # reserved names
for kg in core_language.kernel_globals: for kg in core_language.kernel_globals:
self.use_count[kg] = 1 self.use_count[kg] = 1
for name in "int", "round", "int64", "round64", \ for name in ("int", "round", "int64", "round64",
"range", "Fraction", "Quantity", \ "range", "Fraction", "Quantity",
"s_unit", "Hz_unit", "microcycle_unit": "s_unit", "Hz_unit", "microcycle_unit"):
self.use_count[name] = 1 self.use_count[name] = 1
def new_name(self, base_name): def new_name(self, base_name):
@ -63,7 +69,9 @@ class _ReferenceManager:
else: else:
if _is_in_attr_list(value, ref.attr, "kernel_attr_ro"): if _is_in_attr_list(value, ref.attr, "kernel_attr_ro"):
if store: if store:
raise TypeError("Attempted to assign to read-only kernel attribute") raise TypeError(
"Attempted to assign to read-only"
" kernel attribute")
return getattr(value, ref.attr) return getattr(value, ref.attr)
if _is_in_attr_list(value, ref.attr, "kernel_attr"): if _is_in_attr_list(value, ref.attr, "kernel_attr"):
key = (id(value), ref.attr, None) key = (id(value), ref.attr, None)
@ -76,7 +84,9 @@ class _ReferenceManager:
self.to_inlined[key] = ival self.to_inlined[key] = ival
a = value_to_ast(getattr(value, ref.attr)) a = value_to_ast(getattr(value, ref.attr))
if a is None: if a is None:
raise NotImplementedError("Cannot represent initial value of kernel attribute") raise NotImplementedError(
"Cannot represent initial value"
" of kernel attribute")
self.kernel_attr_init.append(ast.Assign( self.kernel_attr_init.append(ast.Assign(
[ast.Name(iname, ast.Store())], a)) [ast.Name(iname, ast.Store())], a))
return ival return ival
@ -92,12 +102,14 @@ class _ReferenceManager:
self.to_inlined[(id(obj), funcname, name)] = value self.to_inlined[(id(obj), funcname, name)] = value
def get_constants(self, r_obj, r_funcname): def get_constants(self, r_obj, r_funcname):
return {local: v for (objid, funcname, local), v return {
local: v for (objid, funcname, local), v
in self.to_inlined.items() in self.to_inlined.items()
if objid == id(r_obj) if objid == id(r_obj)
and funcname == r_funcname and funcname == r_funcname
and not isinstance(v, (_UserVariable, ast.AST))} and not isinstance(v, (_UserVariable, ast.AST))}
_embeddable_calls = { _embeddable_calls = {
core_language.delay, core_language.at, core_language.now, core_language.delay, core_language.at, core_language.now,
core_language.syscall, core_language.syscall,
@ -105,6 +117,7 @@ _embeddable_calls = {
Fraction, units.Quantity Fraction, units.Quantity
} }
class _ReferenceReplacer(ast.NodeTransformer): class _ReferenceReplacer(ast.NodeTransformer):
def __init__(self, core, rm, obj, funcname): def __init__(self, core, rm, obj, funcname):
self.core = core self.core = core
@ -122,11 +135,13 @@ class _ReferenceReplacer(ast.NodeTransformer):
newnode = ival newnode = ival
else: else:
if store: if store:
raise NotImplementedError("Cannot turn object into user variable") raise NotImplementedError(
"Cannot turn object into user variable")
else: else:
newnode = value_to_ast(ival) newnode = value_to_ast(ival)
if newnode is None: if newnode is None:
raise NotImplementedError("Cannot represent inlined value") raise NotImplementedError(
"Cannot represent inlined value")
return ast.copy_location(newnode, node) return ast.copy_location(newnode, node)
visit_Name = visit_ref visit_Name = visit_ref
@ -143,9 +158,12 @@ class _ReferenceReplacer(ast.NodeTransformer):
ast.Call(func=new_func, args=new_args, ast.Call(func=new_func, args=new_args,
keywords=[], starargs=None, kwargs=None), keywords=[], starargs=None, kwargs=None),
node) node)
elif hasattr(func, "k_function_info") and getattr(func.__self__, func.k_function_info.core_name) is self.core: elif (hasattr(func, "k_function_info")
and getattr(func.__self__, func.k_function_info.core_name)
is self.core):
args = [func.__self__] + new_args args = [func.__self__] + new_args
inlined, _ = inline(self.core, func.k_function_info.k_function, args, dict(), self.rm) inlined, _ = inline(self.core, func.k_function_info.k_function,
args, dict(), self.rm)
return inlined.body return inlined.body
else: else:
args = [ast.Str("rpc"), value_to_ast(self.rm.rpc_map[func])] args = [ast.Str("rpc"), value_to_ast(self.rm.rpc_map[func])]
@ -168,11 +186,13 @@ class _ReferenceReplacer(ast.NodeTransformer):
return node return node
def visit_FunctionDef(self, node): def visit_FunctionDef(self, node):
node.args = ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]) node.args = ast.arguments(args=[], vararg=None, kwonlyargs=[],
kw_defaults=[], kwarg=None, defaults=[])
node.decorator_list = [] node.decorator_list = []
self.generic_visit(node) self.generic_visit(node)
return node return node
class _ListReadOnlyParams(ast.NodeVisitor): class _ListReadOnlyParams(ast.NodeVisitor):
def visit_FunctionDef(self, node): def visit_FunctionDef(self, node):
if hasattr(self, "read_only_params"): if hasattr(self, "read_only_params"):
@ -187,11 +207,13 @@ class _ListReadOnlyParams(ast.NodeVisitor):
except KeyError: except KeyError:
pass pass
def _list_read_only_params(funcdef): def _list_read_only_params(funcdef):
lrp = _ListReadOnlyParams() lrp = _ListReadOnlyParams()
lrp.visit(funcdef) lrp.visit(funcdef)
return lrp.read_only_params return lrp.read_only_params
def _initialize_function_params(funcdef, k_args, k_kwargs, rm): def _initialize_function_params(funcdef, k_args, k_kwargs, rm):
obj = k_args[0] obj = k_args[0]
funcname = funcdef.name funcname = funcdef.name
@ -207,6 +229,7 @@ def _initialize_function_params(funcdef, k_args, k_kwargs, rm):
param_init.append(ast.Assign(targets=[target], value=value)) param_init.append(ast.Assign(targets=[target], value=value))
return param_init return param_init
def inline(core, k_function, k_args, k_kwargs, rm=None): def inline(core, k_function, k_args, k_kwargs, rm=None):
init_kernel_attr = rm is None init_kernel_attr = rm is None
if rm is None: if rm is None:
@ -225,5 +248,6 @@ def inline(core, k_function, k_args, k_kwargs, rm=None):
if init_kernel_attr: if init_kernel_attr:
funcdef.body[0:0] = rm.kernel_attr_init funcdef.body[0:0] = rm.kernel_attr_init
r_rpc_map = dict((rpc_num, rpc_fun) for rpc_fun, rpc_num in rm.rpc_map.items()) r_rpc_map = dict((rpc_num, rpc_fun)
for rpc_fun, rpc_num in rm.rpc_map.items())
return funcdef, r_rpc_map return funcdef, r_rpc_map

View File

@ -1,7 +1,9 @@
import ast, types import ast
import types
from artiq.compiler.tools import * from artiq.compiler.tools import *
# -1 statement duration could not be pre-determined # -1 statement duration could not be pre-determined
# 0 statement has no effect on timeline # 0 statement has no effect on timeline
# >0 statement is a static delay that advances the timeline # >0 statement is a static delay that advances the timeline
@ -10,7 +12,8 @@ def _get_duration(stmt):
if isinstance(stmt, (ast.Expr, ast.Assign)): if isinstance(stmt, (ast.Expr, ast.Assign)):
return _get_duration(stmt.value) return _get_duration(stmt.value)
elif isinstance(stmt, ast.If): elif isinstance(stmt, ast.If):
if all(_get_duration(s) == 0 for s in stmt.body) and all(_get_duration(s) == 0 for s in stmt.orelse): if (all(_get_duration(s) == 0 for s in stmt.body)
and all(_get_duration(s) == 0 for s in stmt.orelse)):
return 0 return 0
else: else:
return -1 return -1
@ -27,6 +30,7 @@ def _get_duration(stmt):
else: else:
return 0 return 0
def _interleave_timelines(timelines): def _interleave_timelines(timelines):
r = [] r = []
@ -38,7 +42,8 @@ def _interleave_timelines(timelines):
except StopIteration: except StopIteration:
pass pass
else: else:
current_stmts.append(types.SimpleNamespace(delay=_get_duration(stmt), stmt=stmt, it=it)) current_stmts.append(types.SimpleNamespace(
delay=_get_duration(stmt), stmt=stmt, it=it))
while current_stmts: while current_stmts:
dt = min(stmt.delay for stmt in current_stmts) dt = min(stmt.delay for stmt in current_stmts)
@ -52,7 +57,8 @@ def _interleave_timelines(timelines):
if stmt.delay == 0: if stmt.delay == 0:
ref_stmt = stmt.stmt ref_stmt = stmt.stmt
delay_stmt = ast.copy_location( delay_stmt = ast.copy_location(
ast.Expr(ast.Call(func=ast.Name("delay", ast.Load()), ast.Expr(ast.Call(
func=ast.Name("delay", ast.Load()),
args=[value_to_ast(dt)], args=[value_to_ast(dt)],
keywords=[], starargs=[], kwargs=[])), keywords=[], starargs=[], kwargs=[])),
ref_stmt) ref_stmt)
@ -76,6 +82,7 @@ def _interleave_timelines(timelines):
return r return r
def _interleave_stmts(stmts): def _interleave_stmts(stmts):
replacements = [] replacements = []
for stmt_i, stmt in enumerate(stmts): for stmt_i, stmt in enumerate(stmts):
@ -101,5 +108,6 @@ def _interleave_stmts(stmts):
stmts[offset+location:offset+location+1] = new_stmts stmts[offset+location:offset+location+1] = new_stmts
offset += len(new_stmts) - 1 offset += len(new_stmts) - 1
def interleave(funcdef): def interleave(funcdef):
_interleave_stmts(funcdef.body) _interleave_stmts(funcdef.body)

View File

@ -3,6 +3,7 @@ from llvm import passes as lp
from artiq.compiler import ir_infer_types, ir_ast_body, ir_values from artiq.compiler import ir_infer_types, ir_ast_body, ir_values
def compile_function(module, env, funcdef): def compile_function(module, env, funcdef):
function_type = lc.Type.function(lc.Type.void(), []) function_type = lc.Type.function(lc.Type.void(), [])
function = module.add_function(function_type, funcdef.name) function = module.add_function(function_type, funcdef.name)
@ -16,6 +17,7 @@ def compile_function(module, env, funcdef):
visitor.visit_statements(funcdef.body) visitor.visit_statements(funcdef.body)
builder.ret_void() builder.ret_void()
def get_runtime_binary(env, funcdef): def get_runtime_binary(env, funcdef):
module = lc.Module.new("main") module = lc.Module.new("main")
env.init_module(module) env.init_module(module)

View File

@ -2,6 +2,7 @@ import ast
from artiq.compiler import ir_values from artiq.compiler import ir_values
class Visitor: class Visitor:
def __init__(self, env, ns, builder=None): def __init__(self, env, ns, builder=None):
self.env = env self.env = env
@ -14,7 +15,8 @@ class Visitor:
try: try:
visitor = getattr(self, method) visitor = getattr(self, method)
except AttributeError: except AttributeError:
raise NotImplementedError("Unsupported node '{}' in expression".format(node.__class__.__name__)) raise NotImplementedError("Unsupported node '{}' in expression"
.format(node.__class__.__name__))
return visitor(node) return visitor(node)
def _visit_expr_Name(self, node): def _visit_expr_Name(self, node):
@ -56,7 +58,8 @@ class Visitor:
ast.UAdd: ir_values.operators.pos, ast.UAdd: ir_values.operators.pos,
ast.USub: ir_values.operators.neg ast.USub: ir_values.operators.neg
} }
return ast_unops[type(node.op)](self.visit_expression(node.operand), self.builder) return ast_unops[type(node.op)](self.visit_expression(node.operand),
self.builder)
def _visit_expr_BinOp(self, node): def _visit_expr_BinOp(self, node):
ast_binops = { ast_binops = {
@ -73,7 +76,9 @@ class Visitor:
ast.BitXor: ir_values.operators.xor, ast.BitXor: ir_values.operators.xor,
ast.BitAnd: ir_values.operators.and_ ast.BitAnd: ir_values.operators.and_
} }
return ast_binops[type(node.op)](self.visit_expression(node.left), self.visit_expression(node.right), self.builder) return ast_binops[type(node.op)](self.visit_expression(node.left),
self.visit_expression(node.right),
self.builder)
def _visit_expr_Compare(self, node): def _visit_expr_Compare(self, node):
ast_cmps = { ast_cmps = {
@ -88,7 +93,8 @@ class Visitor:
old_comparator = self.visit_expression(node.left) old_comparator = self.visit_expression(node.left)
for op, comparator_a in zip(node.ops, node.comparators): for op, comparator_a in zip(node.ops, node.comparators):
comparator = self.visit_expression(comparator_a) comparator = self.visit_expression(comparator_a)
comparison = ast_cmps[type(op)](old_comparator, comparator, self.builder) comparison = ast_cmps[type(op)](old_comparator, comparator,
self.builder)
comparisons.append(comparison) comparisons.append(comparison)
old_comparator = comparator old_comparator = comparator
r = comparisons[0] r = comparisons[0]
@ -106,7 +112,8 @@ class Visitor:
} }
fn = node.func.id fn = node.func.id
if fn in ast_unfuns: if fn in ast_unfuns:
return ast_unfuns[fn](self.visit_expression(node.args[0]), self.builder) return ast_unfuns[fn](self.visit_expression(node.args[0]),
self.builder)
elif fn == "Fraction": elif fn == "Fraction":
r = ir_values.VFraction() r = ir_values.VFraction()
if self.builder is not None: if self.builder is not None:
@ -115,7 +122,8 @@ class Visitor:
r.set_value_nd(self.builder, numerator, denominator) r.set_value_nd(self.builder, numerator, denominator)
return r return r
elif fn == "syscall": elif fn == "syscall":
return self.env.syscall(node.args[0].s, return self.env.syscall(
node.args[0].s,
[self.visit_expression(expr) for expr in node.args[1:]], [self.visit_expression(expr) for expr in node.args[1:]],
self.builder) self.builder)
else: else:
@ -127,7 +135,8 @@ class Visitor:
try: try:
visitor = getattr(self, method) visitor = getattr(self, method)
except AttributeError: except AttributeError:
raise NotImplementedError("Unsupported node '{}' in statement".format(node.__class__.__name__)) raise NotImplementedError("Unsupported node '{}' in statement"
.format(node.__class__.__name__))
visitor(node) visitor(node)
def _visit_stmt_Assign(self, node): def _visit_stmt_Assign(self, node):
@ -139,7 +148,8 @@ class Visitor:
raise NotImplementedError raise NotImplementedError
def _visit_stmt_AugAssign(self, node): def _visit_stmt_AugAssign(self, node):
val = self.visit_expression(ast.BinOp(op=node.op, left=node.target, right=node.value)) val = self.visit_expression(ast.BinOp(op=node.op, left=node.target,
right=node.value))
if isinstance(node.target, ast.Name): if isinstance(node.target, ast.Name):
self.ns[node.target.id].set_value(self.builder, val) self.ns[node.target.id].set_value(self.builder, val)
else: else:
@ -154,8 +164,10 @@ class Visitor:
else_block = function.append_basic_block("i_else") else_block = function.append_basic_block("i_else")
merge_block = function.append_basic_block("i_merge") merge_block = function.append_basic_block("i_merge")
condition = ir_values.operators.bool(self.visit_expression(node.test), self.builder) condition = ir_values.operators.bool(self.visit_expression(node.test),
self.builder.cbranch(condition.get_ssa_value(self.builder), then_block, else_block) self.builder)
self.builder.cbranch(condition.get_ssa_value(self.builder),
then_block, else_block)
self.builder.position_at_end(then_block) self.builder.position_at_end(then_block)
self.visit_statements(node.body) self.visit_statements(node.body)
@ -173,13 +185,17 @@ class Visitor:
else_block = function.append_basic_block("w_else") else_block = function.append_basic_block("w_else")
merge_block = function.append_basic_block("w_merge") merge_block = function.append_basic_block("w_merge")
condition = ir_values.operators.bool(self.visit_expression(node.test), self.builder) condition = ir_values.operators.bool(
self.builder.cbranch(condition.get_ssa_value(self.builder), body_block, else_block) self.visit_expression(node.test), self.builder)
self.builder.cbranch(
condition.get_ssa_value(self.builder), body_block, else_block)
self.builder.position_at_end(body_block) self.builder.position_at_end(body_block)
self.visit_statements(node.body) self.visit_statements(node.body)
condition = ir_values.operators.bool(self.visit_expression(node.test), self.builder) condition = ir_values.operators.bool(
self.builder.cbranch(condition.get_ssa_value(self.builder), body_block, merge_block) self.visit_expression(node.test), self.builder)
self.builder.cbranch(
condition.get_ssa_value(self.builder), body_block, merge_block)
self.builder.position_at_end(else_block) self.builder.position_at_end(else_block)
self.visit_statements(node.orelse) self.visit_statements(node.orelse)

View File

@ -4,6 +4,7 @@ from copy import deepcopy
from artiq.compiler.ir_ast_body import Visitor from artiq.compiler.ir_ast_body import Visitor
class _TypeScanner(ast.NodeVisitor): class _TypeScanner(ast.NodeVisitor):
def __init__(self, env, ns): def __init__(self, env, ns):
self.exprv = Visitor(env, ns) self.exprv = Visitor(env, ns)
@ -21,7 +22,8 @@ class _TypeScanner(ast.NodeVisitor):
raise NotImplementedError raise NotImplementedError
def visit_AugAssign(self, node): def visit_AugAssign(self, node):
val = self.exprv.visit_expression(ast.BinOp(op=node.op, left=node.target, right=node.value)) val = self.exprv.visit_expression(ast.BinOp(
op=node.op, left=node.target, right=node.value))
ns = self.exprv.ns ns = self.exprv.ns
target = node.target target = node.target
if isinstance(target, ast.Name): if isinstance(target, ast.Name):
@ -32,6 +34,7 @@ class _TypeScanner(ast.NodeVisitor):
else: else:
raise NotImplementedError raise NotImplementedError
def infer_types(env, node): def infer_types(env, node):
ns = dict() ns = dict()
while True: while True:

View File

@ -2,6 +2,7 @@ from types import SimpleNamespace
from llvm import core as lc from llvm import core as lc
class _Value: class _Value:
def __init__(self): def __init__(self):
self._llvm_value = None self._llvm_value = None
@ -18,7 +19,8 @@ class _Value:
elif isinstance(self._llvm_value, lc.AllocaInstruction): elif isinstance(self._llvm_value, lc.AllocaInstruction):
builder.store(value, self._llvm_value) builder.store(value, self._llvm_value)
else: else:
raise RuntimeError("Attempted to set LLVM SSA value multiple times") raise RuntimeError(
"Attempted to set LLVM SSA value multiple times")
def alloca(self, builder, name): def alloca(self, builder, name):
if self._llvm_value is not None: if self._llvm_value is not None:
@ -37,6 +39,7 @@ class _Value:
def o_round64(self, builder): def o_round64(self, builder):
return self.o_roundx(64, builder) return self.o_roundx(64, builder)
# None type # None type
class VNone(_Value): class VNone(_Value):
@ -62,6 +65,7 @@ class VNone(_Value):
r.set_const_value(builder, False) r.set_const_value(builder, False)
return r return r
# Integer type # Integer type
class VInt(_Value): class VInt(_Value):
@ -86,7 +90,8 @@ class VInt(_Value):
raise TypeError raise TypeError
def set_value(self, builder, n): def set_value(self, builder, n):
self.set_ssa_value(builder, n.o_intx(self.nbits, builder).get_ssa_value(builder)) self.set_ssa_value(
builder, n.o_intx(self.nbits, builder).get_ssa_value(builder))
def set_const_value(self, builder, n): def set_const_value(self, builder, n):
self.set_ssa_value(builder, lc.Constant.int(self.get_llvm_type(), n)) self.set_ssa_value(builder, lc.Constant.int(self.get_llvm_type(), n))
@ -94,22 +99,31 @@ class VInt(_Value):
def o_bool(self, builder): def o_bool(self, builder):
r = VBool() r = VBool()
if builder is not None: if builder is not None:
r.set_ssa_value(builder, builder.icmp(lc.ICMP_NE, r.set_ssa_value(
self.get_ssa_value(builder), lc.Constant.int(self.get_llvm_type(), 0))) builder, builder.icmp(
lc.ICMP_NE,
self.get_ssa_value(builder),
lc.Constant.int(self.get_llvm_type(), 0)))
return r return r
def o_intx(self, target_bits, builder): def o_intx(self, target_bits, builder):
r = VInt(target_bits) r = VInt(target_bits)
if builder is not None: if builder is not None:
if self.nbits == target_bits: if self.nbits == target_bits:
r.set_ssa_value(builder, self.get_ssa_value(builder)) r.set_ssa_value(
builder, self.get_ssa_value(builder))
if self.nbits > target_bits: if self.nbits > target_bits:
r.set_ssa_value(builder, builder.trunc(self.get_ssa_value(builder), r.get_llvm_type())) r.set_ssa_value(
builder, builder.trunc(self.get_ssa_value(builder),
r.get_llvm_type()))
if self.nbits < target_bits: if self.nbits < target_bits:
r.set_ssa_value(builder, builder.sext(self.get_ssa_value(builder), r.get_llvm_type())) r.set_ssa_value(
builder, builder.sext(self.get_ssa_value(builder),
r.get_llvm_type()))
return r return r
o_roundx = o_intx o_roundx = o_intx
def _make_vint_binop_method(builder_name): def _make_vint_binop_method(builder_name):
def binop_method(self, other, builder): def binop_method(self, other, builder):
if isinstance(other, VInt): if isinstance(other, VInt):
@ -119,15 +133,15 @@ def _make_vint_binop_method(builder_name):
left = self.o_intx(target_bits, builder) left = self.o_intx(target_bits, builder)
right = other.o_intx(target_bits, builder) right = other.o_intx(target_bits, builder)
bf = getattr(builder, builder_name) bf = getattr(builder, builder_name)
r.set_ssa_value(builder, r.set_ssa_value(
bf(left.get_ssa_value(builder), right.get_ssa_value(builder))) builder, bf(left.get_ssa_value(builder),
right.get_ssa_value(builder)))
return r return r
else: else:
return NotImplemented return NotImplemented
return binop_method return binop_method
for _method_name, _builder_name in ( for _method_name, _builder_name in (("o_add", "add"),
("o_add", "add"),
("o_sub", "sub"), ("o_sub", "sub"),
("o_mul", "mul"), ("o_mul", "mul"),
("o_floordiv", "sdiv"), ("o_floordiv", "sdiv"),
@ -137,6 +151,7 @@ for _method_name, _builder_name in (
("o_or", "or_")): ("o_or", "or_")):
setattr(VInt, _method_name, _make_vint_binop_method(_builder_name)) setattr(VInt, _method_name, _make_vint_binop_method(_builder_name))
def _make_vint_cmp_method(icmp_val): def _make_vint_cmp_method(icmp_val):
def cmp_method(self, other, builder): def cmp_method(self, other, builder):
if isinstance(other, VInt): if isinstance(other, VInt):
@ -145,15 +160,17 @@ def _make_vint_cmp_method(icmp_val):
target_bits = max(self.nbits, other.nbits) target_bits = max(self.nbits, other.nbits)
left = self.o_intx(target_bits, builder) left = self.o_intx(target_bits, builder)
right = other.o_intx(target_bits, builder) right = other.o_intx(target_bits, builder)
r.set_ssa_value(builder, r.set_ssa_value(
builder.icmp(icmp_val, left.get_ssa_value(builder), right.get_ssa_value(builder))) builder,
builder.icmp(
icmp_val, left.get_ssa_value(builder),
right.get_ssa_value(builder)))
return r return r
else: else:
return NotImplemented return NotImplemented
return cmp_method return cmp_method
for _method_name, _icmp_val in ( for _method_name, _icmp_val in (("o_eq", lc.ICMP_EQ),
("o_eq", lc.ICMP_EQ),
("o_ne", lc.ICMP_NE), ("o_ne", lc.ICMP_NE),
("o_lt", lc.ICMP_SLT), ("o_lt", lc.ICMP_SLT),
("o_le", lc.ICMP_SLE), ("o_le", lc.ICMP_SLE),
@ -161,6 +178,7 @@ for _method_name, _icmp_val in (
("o_ge", lc.ICMP_SGE)): ("o_ge", lc.ICMP_SGE)):
setattr(VInt, _method_name, _make_vint_cmp_method(_icmp_val)) setattr(VInt, _method_name, _make_vint_cmp_method(_icmp_val))
# Boolean type # Boolean type
class VBool(VInt): class VBool(VInt):
@ -186,24 +204,30 @@ class VBool(VInt):
r.set_ssa_value(builder, self.get_ssa_value(builder)) r.set_ssa_value(builder, self.get_ssa_value(builder))
return r return r
# Fraction type # Fraction type
def _gcd64(builder, a, b): def _gcd64(builder, a, b):
gcd_f = builder.module.get_function_named("__gcd64") gcd_f = builder.module.get_function_named("__gcd64")
return builder.call(gcd_f, [a, b]) return builder.call(gcd_f, [a, b])
def _frac_normalize(builder, numerator, denominator): def _frac_normalize(builder, numerator, denominator):
gcd = _gcd64(numerator, denominator) gcd = _gcd64(numerator, denominator)
numerator = builder.sdiv(numerator, gcd) numerator = builder.sdiv(numerator, gcd)
denominator = builder.sdiv(denominator, gcd) denominator = builder.sdiv(denominator, gcd)
return numerator, denominator return numerator, denominator
def _frac_make_ssa(builder, numerator, denominator): def _frac_make_ssa(builder, numerator, denominator):
value = lc.Constant.undef(lc.Type.vector(lc.Type.int(64), 2)) value = lc.Constant.undef(lc.Type.vector(lc.Type.int(64), 2))
value = builder.insert_element(value, numerator, lc.Constant.int(lc.Type.int(), 0)) value = builder.insert_element(
value = builder.insert_element(value, denominator, lc.Constant.int(lc.Type.int(), 1)) value, numerator, lc.Constant.int(lc.Type.int(), 0))
value = builder.insert_element(
value, denominator, lc.Constant.int(lc.Type.int(), 1))
return value return value
class VFraction(_Value): class VFraction(_Value):
def get_llvm_type(self): def get_llvm_type(self):
return lc.Type.vector(lc.Type.int(64), 2) return lc.Type.vector(lc.Type.int(64), 2)
@ -220,8 +244,10 @@ class VFraction(_Value):
def _nd(self, builder, invert=False): def _nd(self, builder, invert=False):
ssa_value = self.get_ssa_value(builder) ssa_value = self.get_ssa_value(builder)
numerator = builder.extract_element(ssa_value, lc.Constant.int(lc.Type.int(), 0)) numerator = builder.extract_element(
denominator = builder.extract_element(ssa_value, lc.Constant.int(lc.Type.int(), 1)) ssa_value, lc.Constant.int(lc.Type.int(), 0))
denominator = builder.extract_element(
ssa_value, lc.Constant.int(lc.Type.int(), 1))
if invert: if invert:
return denominator, numerator return denominator, numerator
else: else:
@ -230,8 +256,10 @@ class VFraction(_Value):
def set_value_nd(self, builder, numerator, denominator): def set_value_nd(self, builder, numerator, denominator):
numerator = numerator.o_int64(builder).get_ssa_value(builder) numerator = numerator.o_int64(builder).get_ssa_value(builder)
denominator = denominator.o_int64(builder).get_ssa_value(builder) denominator = denominator.o_int64(builder).get_ssa_value(builder)
numerator, denominator = _frac_normalize(builder, numerator, denominator) numerator, denominator = _frac_normalize(
self.set_ssa_value(builder, _frac_make_ssa(builder, numerator, denominator)) builder, numerator, denominator)
self.set_ssa_value(
builder, _frac_make_ssa(builder, numerator, denominator))
def set_value(self, builder, n): def set_value(self, builder, n):
if not isinstance(n, VFraction): if not isinstance(n, VFraction):
@ -242,7 +270,8 @@ class VFraction(_Value):
r = VBool() r = VBool()
if builder is not None: if builder is not None:
zero = lc.Constant.int(lc.Type.int(64), 0) zero = lc.Constant.int(lc.Type.int(64), 0)
numerator = builder.extract_element(self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 0)) numerator = builder.extract_element(
self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 0))
r.set_ssa_value(builder, builder.icmp(lc.ICMP_NE, numerator, zero)) r.set_ssa_value(builder, builder.icmp(lc.ICMP_NE, numerator, zero))
return r return r
@ -261,7 +290,8 @@ class VFraction(_Value):
else: else:
r = VInt(64) r = VInt(64)
numerator, denominator = self._nd(builder) numerator, denominator = self._nd(builder)
h_denominator = builder.ashr(denominator, lc.Constant.int(lc.Type.int(), 1)) h_denominator = builder.ashr(denominator,
lc.Constant.int(lc.Type.int(), 1))
r_numerator = builder.add(numerator, h_denominator) r_numerator = builder.add(numerator, h_denominator)
r.set_ssa_value(builder, builder.sdiv(r_numerator, denominator)) r.set_ssa_value(builder, builder.sdiv(r_numerator, denominator))
return r.o_intx(target_bits, builder) return r.o_intx(target_bits, builder)
@ -272,12 +302,17 @@ class VFraction(_Value):
if builder is not None: if builder is not None:
ee = [] ee = []
for i in range(2): for i in range(2):
es = builder.extract_element(self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), i)) es = builder.extract_element(
eo = builder.extract_element(other.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), i)) self.get_ssa_value(builder),
lc.Constant.int(lc.Type.int(), i))
eo = builder.extract_element(
other.get_ssa_value(builder),
lc.Constant.int(lc.Type.int(), i))
ee.append(builder.icmp(lc.ICMP_EQ, es, eo)) ee.append(builder.icmp(lc.ICMP_EQ, es, eo))
ssa_r = builder.and_(ee[0], ee[1]) ssa_r = builder.and_(ee[0], ee[1])
if ne: if ne:
ssa_r = builder.xor(ssa_r, lc.Constant.int(lc.Type.int(1), 1)) ssa_r = builder.xor(ssa_r,
lc.Constant.int(lc.Type.int(1), 1))
r.set_ssa_value(builder, ssa_r) r.set_ssa_value(builder, ssa_r)
return r return r
else: else:
@ -307,7 +342,8 @@ class VFraction(_Value):
i = builder.sdiv(i, gcd) i = builder.sdiv(i, gcd)
denominator = builder.sdiv(denominator, gcd) denominator = builder.sdiv(denominator, gcd)
numerator = builder.mul(numerator, i) numerator = builder.mul(numerator, i)
self.set_ssa_value(builder, _frac_make_ssa(builder, numerator, denominator)) self.set_ssa_value(builder, _frac_make_ssa(builder, numerator,
denominator))
elif isinstance(other, VFraction): elif isinstance(other, VFraction):
if builder is None: if builder is None:
return r return r
@ -320,8 +356,10 @@ class VFraction(_Value):
else: else:
numerator = builder.mul(numerator, onumerator) numerator = builder.mul(numerator, onumerator)
denominator = builder.mul(denominator, odenominator) denominator = builder.mul(denominator, odenominator)
numerator, denominator = _frac_normalize(builder, numerator, denominator) numerator, denominator = _frac_normalize(builder, numerator,
self.set_ssa_value(builder, _frac_make_ssa(builder, numerator, denominator)) denominator)
self.set_ssa_value(
builder, _frac_make_ssa(builder, numerator, denominator))
else: else:
return NotImplemented return NotImplemented
@ -351,6 +389,7 @@ class VFraction(_Value):
else: else:
return r.o_int(builder) return r.o_int(builder)
# Operators # Operators
def _make_unary_operator(op_name): def _make_unary_operator(op_name):
@ -358,10 +397,13 @@ def _make_unary_operator(op_name):
try: try:
opf = getattr(x, "o_"+op_name) opf = getattr(x, "o_"+op_name)
except AttributeError: except AttributeError:
raise TypeError("Unsupported operand type for {}: {}".format(op_name, type(x).__name__)) raise TypeError(
"Unsupported operand type for {}: {}"
.format(op_name, type(x).__name__))
return opf(builder) return opf(builder)
return op return op
def _make_binary_operator(op_name): def _make_binary_operator(op_name):
def op(l, r, builder): def op(l, r, builder):
try: try:
@ -378,14 +420,17 @@ def _make_binary_operator(op_name):
else: else:
result = ropf(l, builder) result = ropf(l, builder)
if result is NotImplemented: if result is NotImplemented:
raise TypeError("Unsupported operand types for {}: {} and {}".format( raise TypeError(
op_name, type(l).__name__, type(r).__name__)) "Unsupported operand types for {}: {} and {}"
.format(op_name, type(l).__name__, type(r).__name__))
return result return result
return op return op
def _make_operators(): def _make_operators():
d = dict() d = dict()
for op_name in ("bool", "int", "int64", "round", "round64", "inv", "pos", "neg"): for op_name in ("bool", "int", "int64", "round", "round64",
"inv", "pos", "neg"):
d[op_name] = _make_unary_operator(op_name) d[op_name] = _make_unary_operator(op_name)
d["not_"] = _make_binary_operator("not") d["not_"] = _make_binary_operator("not")
for op_name in ("add", "sub", "mul", for op_name in ("add", "sub", "mul",
@ -399,7 +444,8 @@ def _make_operators():
operators = _make_operators() operators = _make_operators()
def init_module(module): def init_module(module):
func_type = lc.Type.function(lc.Type.int(64), func_type = lc.Type.function(
[lc.Type.int(64), lc.Type.int(64)]) lc.Type.int(64), [lc.Type.int(64), lc.Type.int(64)])
module.add_function(func_type, "__gcd64") module.add_function(func_type, "__gcd64")

View File

@ -3,11 +3,14 @@ import ast
from artiq.compiler.tools import value_to_ast from artiq.compiler.tools import value_to_ast
from artiq.language.core import int64 from artiq.language.core import int64
def _insert_int64(node): def _insert_int64(node):
return ast.copy_location( return ast.copy_location(
ast.Call(func=ast.Name("int64", ast.Load()), ast.Call(func=ast.Name("int64", ast.Load()),
args=[node], args=[node],
keywords=[], starargs=[], kwargs=[]), node) keywords=[], starargs=[], kwargs=[]),
node)
class _TimeLowerer(ast.NodeTransformer): class _TimeLowerer(ast.NodeTransformer):
def visit_Call(self, node): def visit_Call(self, node):
@ -19,11 +22,13 @@ class _TimeLowerer(ast.NodeTransformer):
def visit_Expr(self, node): def visit_Expr(self, node):
self.generic_visit(node) self.generic_visit(node)
if isinstance(node.value, ast.Call) and isinstance(node.value.func, ast.Name): if (isinstance(node.value, ast.Call)
and isinstance(node.value.func, ast.Name)):
funcname = node.value.func.id funcname = node.value.func.id
if funcname == "delay": if funcname == "delay":
return ast.copy_location( return ast.copy_location(
ast.AugAssign(target=ast.Name("now", ast.Store()), op=ast.Add(), ast.AugAssign(target=ast.Name("now", ast.Store()),
op=ast.Add(),
value=_insert_int64(node.value.args[0])), value=_insert_int64(node.value.args[0])),
node) node)
elif funcname == "at": elif funcname == "at":
@ -36,8 +41,10 @@ class _TimeLowerer(ast.NodeTransformer):
else: else:
return node return node
def lower_time(funcdef, initial_time): def lower_time(funcdef, initial_time):
_TimeLowerer().visit(funcdef) _TimeLowerer().visit(funcdef)
funcdef.body.insert(0, ast.copy_location( funcdef.body.insert(0, ast.copy_location(
ast.Assign(targets=[ast.Name("now", ast.Store())], value=value_to_ast(int64(initial_time))), ast.Assign(targets=[ast.Name("now", ast.Store())],
value=value_to_ast(int64(initial_time))),
funcdef)) funcdef))

View File

@ -3,6 +3,7 @@ import ast
from artiq.compiler.tools import value_to_ast from artiq.compiler.tools import value_to_ast
from artiq.language import units from artiq.language import units
# TODO: # TODO:
# * track variable and expression dimensions # * track variable and expression dimensions
# * raise exception on dimension errors in expressions # * raise exception on dimension errors in expressions
@ -28,7 +29,9 @@ class _UnitsLowerer(ast.NodeTransformer):
node = node.args[0] node = node.args[0]
else: else:
node = ast.copy_location( node = ast.copy_location(
ast.BinOp(left=node.args[0], op=ast.Div(), right=value_to_ast(self.ref_period)), ast.BinOp(left=node.args[0],
op=ast.Div(),
right=value_to_ast(self.ref_period)),
node) node)
else: else:
node = node.args[0] node = node.args[0]
@ -36,7 +39,9 @@ class _UnitsLowerer(ast.NodeTransformer):
self.generic_visit(node) self.generic_visit(node)
return node return node
def lower_units(funcdef, ref_period): def lower_units(funcdef, ref_period):
if not isinstance(ref_period, units.Quantity) or ref_period.unit is not units.s_unit: if (not isinstance(ref_period, units.Quantity)
or ref_period.unit is not units.s_unit):
raise units.DimensionError("Reference period not expressed in seconds") raise units.DimensionError("Reference period not expressed in seconds")
_UnitsLowerer(ref_period.amount).visit(funcdef) _UnitsLowerer(ref_period.amount).visit(funcdef)

View File

@ -4,6 +4,7 @@ from fractions import Fraction
from artiq.language import core as core_language from artiq.language import core as core_language
from artiq.language import units from artiq.language import units
def eval_ast(expr, symdict=dict()): def eval_ast(expr, symdict=dict()):
if not isinstance(expr, ast.Expression): if not isinstance(expr, ast.Expression):
expr = ast.copy_location(ast.Expression(expr), expr) expr = ast.copy_location(ast.Expression(expr), expr)
@ -11,6 +12,7 @@ def eval_ast(expr, symdict=dict()):
code = compile(expr, "<ast>", "eval") code = compile(expr, "<ast>", "eval")
return eval(code, symdict) return eval(code, symdict)
def value_to_ast(value): def value_to_ast(value):
if isinstance(value, core_language.int64): # must be before int if isinstance(value, core_language.int64): # must be before int
return ast.Call( return ast.Call(
@ -20,7 +22,8 @@ def value_to_ast(value):
elif isinstance(value, int): elif isinstance(value, int):
return ast.Num(value) return ast.Num(value)
elif isinstance(value, Fraction): elif isinstance(value, Fraction):
return ast.Call(func=ast.Name("Fraction", ast.Load()), return ast.Call(
func=ast.Name("Fraction", ast.Load()),
args=[ast.Num(value.numerator), ast.Num(value.denominator)], args=[ast.Num(value.numerator), ast.Num(value.denominator)],
keywords=[], starargs=None, kwargs=None) keywords=[], starargs=None, kwargs=None)
elif isinstance(value, str): elif isinstance(value, str):
@ -32,13 +35,16 @@ def value_to_ast(value):
if isinstance(value, units.Quantity): if isinstance(value, units.Quantity):
return ast.Call( return ast.Call(
func=ast.Name("Quantity", ast.Load()), func=ast.Name("Quantity", ast.Load()),
args=[value_to_ast(value.amount), ast.Name(value.unit.name+"_unit", ast.Load())], args=[value_to_ast(value.amount),
ast.Name(value.unit.name+"_unit", ast.Load())],
keywords=[], starargs=None, kwargs=None) keywords=[], starargs=None, kwargs=None)
return None return None
class NotConstant(Exception): class NotConstant(Exception):
pass pass
def eval_constant(node): def eval_constant(node):
if isinstance(node, ast.Num): if isinstance(node, ast.Num):
return node.n return node.n
@ -47,7 +53,8 @@ def eval_constant(node):
elif isinstance(node, ast.Call): elif isinstance(node, ast.Call):
funcname = node.func.id funcname = node.func.id
if funcname == "Fraction": if funcname == "Fraction":
numerator, denominator = eval_constant(node.args[0]), eval_constant(node.args[1]) numerator = eval_constant(node.args[0])
denominator = eval_constant(node.args[1])
return Fraction(numerator, denominator) return Fraction(numerator, denominator)
elif funcname == "Quantity": elif funcname == "Quantity":
amount, unit = node.args amount, unit = node.args

View File

@ -1,11 +1,12 @@
import sys import sys
import ast import ast
import os
# Large float and imaginary literals get turned into infinities in the AST. # Large float and imaginary literals get turned into infinities in the AST.
# We unparse those infinities to INFSTR. # We unparse those infinities to INFSTR.
INFSTR = "1e" + repr(sys.float_info.max_10_exp + 1) INFSTR = "1e" + repr(sys.float_info.max_10_exp + 1)
def interleave(inter, f, seq): def interleave(inter, f, seq):
"""Call f on each item in seq, calling inter() in between. """Call f on each item in seq, calling inter() in between.
""" """
@ -19,12 +20,13 @@ def interleave(inter, f, seq):
inter() inter()
f(x) f(x)
class Unparser: class Unparser:
"""Methods in this class recursively traverse an AST and """Methods in this class recursively traverse an AST and
output source code for the abstract syntax; original formatting output source code for the abstract syntax; original formatting
is disregarded. """ is disregarded. """
def __init__(self, tree, file = sys.stdout): def __init__(self, tree, file=sys.stdout):
"""Unparser(tree, file=sys.stdout) -> None. """Unparser(tree, file=sys.stdout) -> None.
Print the source for tree to file.""" Print the source for tree to file."""
self.f = file self.f = file
@ -33,7 +35,7 @@ class Unparser:
print("", file=self.f) print("", file=self.f)
self.f.flush() self.f.flush()
def fill(self, text = ""): def fill(self, text=""):
"Indent a piece of text, according to the current indentation level" "Indent a piece of text, according to the current indentation level"
self.f.write("\n"+" "*self._indent + text) self.f.write("\n"+" "*self._indent + text)
@ -59,13 +61,12 @@ class Unparser:
meth = getattr(self, "_"+tree.__class__.__name__) meth = getattr(self, "_"+tree.__class__.__name__)
meth(tree) meth(tree)
# Unparsing methods
############### Unparsing methods ###################### #
# There should be one method per concrete grammar type # # There should be one method per concrete grammar type
# Constructors should be grouped by sum type. Ideally, # # Constructors should be grouped by sum type. Ideally,
# this would follow the order in the grammar, but # # this would follow the order in the grammar, but
# currently doesn't. # # currently doesn't.
########################################################
def _Module(self, tree): def _Module(self, tree):
for stmt in tree.body: for stmt in tree.body:
@ -201,21 +202,29 @@ class Unparser:
self.write("(") self.write("(")
comma = False comma = False
for e in t.bases: for e in t.bases:
if comma: self.write(", ") if comma:
else: comma = True self.write(", ")
else:
comma = True
self.dispatch(e) self.dispatch(e)
for e in t.keywords: for e in t.keywords:
if comma: self.write(", ") if comma:
else: comma = True self.write(", ")
else:
comma = True
self.dispatch(e) self.dispatch(e)
if t.starargs: if t.starargs:
if comma: self.write(", ") if comma:
else: comma = True self.write(", ")
else:
comma = True
self.write("*") self.write("*")
self.dispatch(t.starargs) self.dispatch(t.starargs)
if t.kwargs: if t.kwargs:
if comma: self.write(", ") if comma:
else: comma = True self.write(", ")
else:
comma = True
self.write("**") self.write("**")
self.dispatch(t.kwargs) self.dispatch(t.kwargs)
self.write(")") self.write(")")
@ -372,6 +381,7 @@ class Unparser:
def _Dict(self, t): def _Dict(self, t):
self.write("{") self.write("{")
def write_pair(pair): def write_pair(pair):
(k, v) = pair (k, v) = pair
self.dispatch(k) self.dispatch(k)
@ -390,7 +400,8 @@ class Unparser:
interleave(lambda: self.write(", "), self.dispatch, t.elts) interleave(lambda: self.write(", "), self.dispatch, t.elts)
self.write(")") self.write(")")
unop = {"Invert":"~", "Not": "not", "UAdd":"+", "USub":"-"} unop = {"Invert": "~", "Not": "not", "UAdd": "+", "USub": "-"}
def _UnaryOp(self, t): def _UnaryOp(self, t):
self.write("(") self.write("(")
self.write(self.unop[t.op.__class__.__name__]) self.write(self.unop[t.op.__class__.__name__])
@ -398,9 +409,11 @@ class Unparser:
self.dispatch(t.operand) self.dispatch(t.operand)
self.write(")") self.write(")")
binop = { "Add":"+", "Sub":"-", "Mult":"*", "Div":"/", "Mod":"%", binop = {"Add": "+", "Sub": "-", "Mult": "*", "Div": "/", "Mod": "%",
"LShift":"<<", "RShift":">>", "BitOr":"|", "BitXor":"^", "BitAnd":"&", "LShift": "<<", "RShift": ">>",
"FloorDiv":"//", "Pow": "**"} "BitOr": "|", "BitXor": "^", "BitAnd": "&",
"FloorDiv": "//", "Pow": "**"}
def _BinOp(self, t): def _BinOp(self, t):
self.write("(") self.write("(")
self.dispatch(t.left) self.dispatch(t.left)
@ -408,8 +421,10 @@ class Unparser:
self.dispatch(t.right) self.dispatch(t.right)
self.write(")") self.write(")")
cmpops = {"Eq":"==", "NotEq":"!=", "Lt":"<", "LtE":"<=", "Gt":">", "GtE":">=", cmpops = {"Eq": "==", "NotEq": "!=",
"Is":"is", "IsNot":"is not", "In":"in", "NotIn":"not in"} "Lt": "<", "LtE": "<=", "Gt": ">", "GtE": ">=",
"Is": "is", "IsNot": "is not", "In": "in", "NotIn": "not in"}
def _Compare(self, t): def _Compare(self, t):
self.write("(") self.write("(")
self.dispatch(t.left) self.dispatch(t.left)
@ -418,14 +433,15 @@ class Unparser:
self.dispatch(e) self.dispatch(e)
self.write(")") self.write(")")
boolops = {ast.And: 'and', ast.Or: 'or'} boolops = {ast.And: "and", ast.Or: "or"}
def _BoolOp(self, t): def _BoolOp(self, t):
self.write("(") self.write("(")
s = " %s " % self.boolops[t.op.__class__] s = " %s " % self.boolops[t.op.__class__]
interleave(lambda: self.write(s), self.dispatch, t.values) interleave(lambda: self.write(s), self.dispatch, t.values)
self.write(")") self.write(")")
def _Attribute(self,t): def _Attribute(self, t):
self.dispatch(t.value) self.dispatch(t.value)
# Special case: 3.__abs__() is a syntax error, so if t.value # Special case: 3.__abs__() is a syntax error, so if t.value
# is an integer literal then we need to either parenthesize # is an integer literal then we need to either parenthesize
@ -440,21 +456,29 @@ class Unparser:
self.write("(") self.write("(")
comma = False comma = False
for e in t.args: for e in t.args:
if comma: self.write(", ") if comma:
else: comma = True self.write(", ")
else:
comma = True
self.dispatch(e) self.dispatch(e)
for e in t.keywords: for e in t.keywords:
if comma: self.write(", ") if comma:
else: comma = True self.write(", ")
else:
comma = True
self.dispatch(e) self.dispatch(e)
if t.starargs: if t.starargs:
if comma: self.write(", ") if comma:
else: comma = True self.write(", ")
else:
comma = True
self.write("*") self.write("*")
self.dispatch(t.starargs) self.dispatch(t.starargs)
if t.kwargs: if t.kwargs:
if comma: self.write(", ") if comma:
else: comma = True self.write(", ")
else:
comma = True
self.write("**") self.write("**")
self.dispatch(t.kwargs) self.dispatch(t.kwargs)
self.write(")") self.write(")")
@ -502,8 +526,10 @@ class Unparser:
# normal arguments # normal arguments
defaults = [None] * (len(t.args) - len(t.defaults)) + t.defaults defaults = [None] * (len(t.args) - len(t.defaults)) + t.defaults
for a, d in zip(t.args, defaults): for a, d in zip(t.args, defaults):
if first:first = False if first:
else: self.write(", ") first = False
else:
self.write(", ")
self.dispatch(a) self.dispatch(a)
if d: if d:
self.write("=") self.write("=")
@ -511,8 +537,10 @@ class Unparser:
# varargs, or bare '*' if no varargs but keyword-only arguments present # varargs, or bare '*' if no varargs but keyword-only arguments present
if t.vararg or t.kwonlyargs: if t.vararg or t.kwonlyargs:
if first:first = False if first:
else: self.write(", ") first = False
else:
self.write(", ")
self.write("*") self.write("*")
if t.vararg: if t.vararg:
self.write(t.vararg.arg) self.write(t.vararg.arg)
@ -523,8 +551,10 @@ class Unparser:
# keyword-only arguments # keyword-only arguments
if t.kwonlyargs: if t.kwonlyargs:
for a, d in zip(t.kwonlyargs, t.kw_defaults): for a, d in zip(t.kwonlyargs, t.kw_defaults):
if first:first = False if first:
else: self.write(", ") first = False
else:
self.write(", ")
self.dispatch(a), self.dispatch(a),
if d: if d:
self.write("=") self.write("=")
@ -532,8 +562,10 @@ class Unparser:
# kwargs # kwargs
if t.kwarg: if t.kwarg:
if first:first = False if first:
else: self.write(", ") first = False
else:
self.write(", ")
self.write("**"+t.kwarg.arg) self.write("**"+t.kwarg.arg)
if t.kwarg.annotation: if t.kwarg.annotation:
self.write(": ") self.write(": ")

View File

@ -2,6 +2,7 @@ import ast
from artiq.compiler.tools import eval_ast, value_to_ast from artiq.compiler.tools import eval_ast, value_to_ast
def _count_stmts(node): def _count_stmts(node):
if isinstance(node, (ast.For, ast.While, ast.If)): if isinstance(node, (ast.For, ast.While, ast.If)):
return 1 + _count_stmts(node.body) + _count_stmts(node.orelse) return 1 + _count_stmts(node.body) + _count_stmts(node.orelse)
@ -12,6 +13,7 @@ def _count_stmts(node):
else: else:
return 1 return 1
class _LoopUnroller(ast.NodeTransformer): class _LoopUnroller(ast.NodeTransformer):
def __init__(self, limit): def __init__(self, limit):
self.limit = limit self.limit = limit
@ -32,7 +34,9 @@ class _LoopUnroller(ast.NodeTransformer):
replacement = None replacement = None
break break
replacement.append(ast.copy_location( replacement.append(ast.copy_location(
ast.Assign(targets=[node.target], value=value_to_ast(i)), node)) ast.Assign(targets=[node.target],
value=value_to_ast(i)),
node))
replacement += node.body replacement += node.body
if replacement is not None: if replacement is not None:
return replacement return replacement
@ -43,5 +47,6 @@ class _LoopUnroller(ast.NodeTransformer):
else: else:
return node.orelse return node.orelse
def unroll_loops(node, limit): def unroll_loops(node, limit):
_LoopUnroller(limit).visit(node) _LoopUnroller(limit).visit(node)

View File

@ -6,6 +6,7 @@ from artiq.compiler.interleave import interleave
from artiq.compiler.lower_time import lower_time from artiq.compiler.lower_time import lower_time
from artiq.compiler.ir import get_runtime_binary from artiq.compiler.ir import get_runtime_binary
class Core: class Core:
def __init__(self, core_com, runtime_env=None): def __init__(self, core_com, runtime_env=None):
if runtime_env is None: if runtime_env is None:

View File

@ -3,6 +3,7 @@ from operator import itemgetter
from artiq.devices.runtime import LinkInterface from artiq.devices.runtime import LinkInterface
from artiq.language.units import ns from artiq.language.units import ns
class _RuntimeEnvironment(LinkInterface): class _RuntimeEnvironment(LinkInterface):
def __init__(self, ref_period): def __init__(self, ref_period):
self.ref_period = ref_period self.ref_period = ref_period
@ -10,6 +11,7 @@ class _RuntimeEnvironment(LinkInterface):
def emit_object(self): def emit_object(self):
return str(self.module) return str(self.module)
class CoreCom: class CoreCom:
def get_runtime_env(self): def get_runtime_env(self):
return _RuntimeEnvironment(10*ns) return _RuntimeEnvironment(10*ns)

View File

@ -1,18 +1,24 @@
import os, termios, struct, zlib import os
import termios
import struct
import zlib
from enum import Enum from enum import Enum
from artiq.language import units from artiq.language import units
from artiq.devices.runtime import Environment from artiq.devices.runtime import Environment
class UnsupportedDevice(Exception): class UnsupportedDevice(Exception):
pass pass
class _MsgType(Enum): class _MsgType(Enum):
REQUEST_IDENT = 0x01 REQUEST_IDENT = 0x01
LOAD_KERNEL = 0x02 LOAD_KERNEL = 0x02
KERNEL_FINISHED = 0x03 KERNEL_FINISHED = 0x03
RPC_REQUEST = 0x04 RPC_REQUEST = 0x04
def _write_exactly(f, data): def _write_exactly(f, data):
remaining = len(data) remaining = len(data)
pos = 0 pos = 0
@ -21,12 +27,14 @@ def _write_exactly(f, data):
remaining -= written remaining -= written
pos += written pos += written
def _read_exactly(f, n): def _read_exactly(f, n):
r = bytes() r = bytes()
while(len(r) < n): while(len(r) < n):
r += f.read(n - len(r)) r += f.read(n - len(r))
return r return r
class CoreCom: class CoreCom:
def __init__(self, dev="/dev/ttyUSB1", baud=115200): def __init__(self, dev="/dev/ttyUSB1", baud=115200):
self._fd = os.open(dev, os.O_RDWR | os.O_NOCTTY) self._fd = os.open(dev, os.O_RDWR | os.O_NOCTTY)
@ -56,8 +64,10 @@ class CoreCom:
self.close() self.close()
def get_runtime_env(self): def get_runtime_env(self):
_write_exactly(self.port, struct.pack(">lb", 0x5a5a5a5a, _MsgType.REQUEST_IDENT.value)) _write_exactly(self.port, struct.pack(
# FIXME: when loading immediately after a board reset, we erroneously get some zeros back. ">lb", 0x5a5a5a5a, _MsgType.REQUEST_IDENT.value))
# FIXME: when loading immediately after a board reset,
# we erroneously get some zeros back.
# Ignore them with a warning for now. # Ignore them with a warning for now.
spurious_zero_count = 0 spurious_zero_count = 0
while True: while True:
@ -67,7 +77,8 @@ class CoreCom:
else: else:
break break
if spurious_zero_count: if spurious_zero_count:
print("Warning: received {} spurious zeros".format(spurious_zero_count)) print("Warning: received {} spurious zeros"
.format(spurious_zero_count))
runtime_id = chr(reply) runtime_id = chr(reply)
for i in range(3): for i in range(3):
(reply, ) = struct.unpack("b", _read_exactly(self.port, 1)) (reply, ) = struct.unpack("b", _read_exactly(self.port, 1))
@ -78,8 +89,10 @@ class CoreCom:
return Environment(ref_period*units.ps) return Environment(ref_period*units.ps)
def run(self, kcode): def run(self, kcode):
_write_exactly(self.port, struct.pack(">lblL", _write_exactly(self.port, struct.pack(
0x5a5a5a5a, _MsgType.LOAD_KERNEL.value, len(kcode), zlib.crc32(kcode))) ">lblL",
0x5a5a5a5a, _MsgType.LOAD_KERNEL.value,
len(kcode), zlib.crc32(kcode)))
_write_exactly(self.port, kcode) _write_exactly(self.port, kcode)
(reply, ) = struct.unpack("b", _read_exactly(self.port, 1)) (reply, ) = struct.unpack("b", _read_exactly(self.port, 1))
if reply != 0x4f: if reply != 0x4f:
@ -101,10 +114,12 @@ class CoreCom:
if msg == _MsgType.KERNEL_FINISHED: if msg == _MsgType.KERNEL_FINISHED:
return return
elif msg == _MsgType.RPC_REQUEST: elif msg == _MsgType.RPC_REQUEST:
rpc_num, n_args = struct.unpack(">hb", _read_exactly(self.port, 3)) rpc_num, n_args = struct.unpack(">hb",
_read_exactly(self.port, 3))
args = [] args = []
for i in range(n_args): for i in range(n_args):
args.append(*struct.unpack(">l", _read_exactly(self.port, 4))) args.append(*struct.unpack(">l",
_read_exactly(self.port, 4)))
r = rpc_map[rpc_num](*args) r = rpc_map[rpc_num](*args)
if r is None: if r is None:
r = 0 r = 0

View File

@ -1,6 +1,7 @@
from artiq.language.core import * from artiq.language.core import *
from artiq.language.units import * from artiq.language.units import *
class DDS(AutoContext): class DDS(AutoContext):
parameters = "dds_sysclk reg_channel rtio_channel" parameters = "dds_sysclk reg_channel rtio_channel"
@ -13,7 +14,8 @@ class DDS(AutoContext):
def pulse(self, frequency, duration): def pulse(self, frequency, duration):
if self._previous_frequency != frequency: if self._previous_frequency != frequency:
syscall("rtio_sync", self.rtio_channel) # wait until output is off syscall("rtio_sync", self.rtio_channel) # wait until output is off
syscall("dds_program", self.reg_channel, int(2**32*frequency/self.dds_sysclk)) syscall("dds_program", self.reg_channel,
int(2**32*frequency/self.dds_sysclk))
self._previous_frequency = frequency self._previous_frequency = frequency
syscall("rtio_set", now(), self.rtio_channel, 1) syscall("rtio_set", now(), self.rtio_channel, 1)
delay(duration) delay(duration)

View File

@ -1,5 +1,6 @@
from artiq.language.core import * from artiq.language.core import *
class GPIOOut(AutoContext): class GPIOOut(AutoContext):
parameters = "channel" parameters = "channel"

View File

@ -3,6 +3,7 @@ from llvm import target as lt
from artiq.compiler import ir_values from artiq.compiler import ir_values
lt.initialize_all() lt.initialize_all()
_syscalls = { _syscalls = {
@ -25,6 +26,7 @@ _chr_to_value = {
"I": lambda: ir_values.VInt(64) "I": lambda: ir_values.VInt(64)
} }
def _str_to_functype(s): def _str_to_functype(s):
assert(s[-2] == ":") assert(s[-2] == ":")
type_ret = _chr_to_type[s[-1]]() type_ret = _chr_to_type[s[-1]]()
@ -37,7 +39,10 @@ def _str_to_functype(s):
var_arg_fixcount = n var_arg_fixcount = n
else: else:
type_args.append(_chr_to_type[c]()) type_args.append(_chr_to_type[c]())
return var_arg_fixcount, lc.Type.function(type_ret, type_args, var_arg=var_arg_fixcount is not None) return (var_arg_fixcount,
lc.Type.function(type_ret, type_args,
var_arg=var_arg_fixcount is not None))
class LinkInterface: class LinkInterface:
def init_module(self, module): def init_module(self, module):
@ -58,10 +63,12 @@ class LinkInterface:
args = args[:fixcount] \ args = args[:fixcount] \
+ [lc.Constant.int(lc.Type.int(), len(args) - fixcount)] \ + [lc.Constant.int(lc.Type.int(), len(args) - fixcount)] \
+ args[fixcount:] + args[fixcount:]
llvm_function = self.module.get_function_named("__syscall_"+syscall_name) llvm_function = self.module.get_function_named(
"__syscall_" + syscall_name)
r.set_ssa_value(builder, builder.call(llvm_function, args)) r.set_ssa_value(builder, builder.call(llvm_function, args))
return r return r
class Environment(LinkInterface): class Environment(LinkInterface):
def __init__(self, ref_period): def __init__(self, ref_period):
self.ref_period = ref_period self.ref_period = ref_period

View File

@ -1,5 +1,6 @@
from artiq.language.core import * from artiq.language.core import *
class TTLOut(AutoContext): class TTLOut(AutoContext):
parameters = "channel" parameters = "channel"

View File

@ -3,6 +3,7 @@ from fractions import Fraction
from artiq.language import units from artiq.language import units
class int64(int): class int64(int):
pass pass
@ -14,8 +15,7 @@ def _make_int64_op_method(int_method):
return r return r
return method return method
for _op_name in ( for _op_name in ("neg", "pos", "abs", "invert", "round",
"neg", "pos", "abs", "invert", "round",
"add", "radd", "sub", "rsub", "mul", "rmul", "pow", "rpow", "add", "radd", "sub", "rsub", "mul", "rmul", "pow", "rpow",
"lshift", "rlshift", "rshift", "rrshift", "lshift", "rlshift", "rshift", "rrshift",
"and", "rand", "xor", "rxor", "or", "ror", "and", "rand", "xor", "rxor", "or", "ror",
@ -24,18 +24,21 @@ for _op_name in (
orig_method = getattr(int, method_name) orig_method = getattr(int, method_name)
setattr(int64, method_name, _make_int64_op_method(orig_method)) setattr(int64, method_name, _make_int64_op_method(orig_method))
for _op_name in ( for _op_name in ("add", "sub", "mul", "floordiv", "mod",
"add", "sub", "mul", "floordiv", "mod",
"pow", "lshift", "rshift", "lshift", "pow", "lshift", "rshift", "lshift",
"and", "xor", "or"): "and", "xor", "or"):
op_method = getattr(int, "__" + _op_name + "__") op_method = getattr(int, "__" + _op_name + "__")
setattr(int64, "__i" + _op_name + "__", _make_int64_op_method(op_method)) setattr(int64, "__i" + _op_name + "__", _make_int64_op_method(op_method))
def round64(x): def round64(x):
return int64(round(x)) return int64(round(x))
def _make_kernel_ro(value): def _make_kernel_ro(value):
return isinstance(value, (bool, int, int64, float, Fraction, units.Quantity)) return isinstance(
value, (bool, int, int64, float, Fraction, units.Quantity))
class AutoContext: class AutoContext:
parameters = "" parameters = ""
@ -76,25 +79,31 @@ class AutoContext:
""" Overload this function to add sub-experiments""" """ Overload this function to add sub-experiments"""
pass pass
KernelFunctionInfo = namedtuple("KernelFunctionInfo", "core_name k_function") KernelFunctionInfo = namedtuple("KernelFunctionInfo", "core_name k_function")
def kernel(arg): def kernel(arg):
if isinstance(arg, str): if isinstance(arg, str):
def real_decorator(k_function): def real_decorator(k_function):
def run_on_core(exp, *k_args, **k_kwargs): def run_on_core(exp, *k_args, **k_kwargs):
getattr(exp, arg).run(k_function, ((exp,) + k_args), k_kwargs) getattr(exp, arg).run(k_function, ((exp,) + k_args), k_kwargs)
run_on_core.k_function_info = KernelFunctionInfo(core_name=arg, k_function=k_function) run_on_core.k_function_info = KernelFunctionInfo(
core_name=arg, k_function=k_function)
return run_on_core return run_on_core
return real_decorator return real_decorator
else: else:
def run_on_core(exp, *k_args, **k_kwargs): def run_on_core(exp, *k_args, **k_kwargs):
exp.core.run(arg, ((exp,) + k_args), k_kwargs) exp.core.run(arg, ((exp,) + k_args), k_kwargs)
run_on_core.k_function_info = KernelFunctionInfo(core_name="core", k_function=arg) run_on_core.k_function_info = KernelFunctionInfo(
core_name="core", k_function=arg)
return run_on_core return run_on_core
class _DummyTimeManager: class _DummyTimeManager:
def _not_implemented(self, *args, **kwargs): def _not_implemented(self, *args, **kwargs):
raise NotImplementedError("Attempted to interpret kernel without a time manager") raise NotImplementedError(
"Attempted to interpret kernel without a time manager")
enter_sequential = _not_implemented enter_sequential = _not_implemented
enter_parallel = _not_implemented enter_parallel = _not_implemented
@ -105,16 +114,20 @@ class _DummyTimeManager:
_time_manager = _DummyTimeManager() _time_manager = _DummyTimeManager()
def set_time_manager(time_manager): def set_time_manager(time_manager):
global _time_manager global _time_manager
_time_manager = time_manager _time_manager = time_manager
class _DummySyscallManager: class _DummySyscallManager:
def do(self, *args): def do(self, *args):
raise NotImplementedError("Attempted to interpret kernel without a syscall manager") raise NotImplementedError(
"Attempted to interpret kernel without a syscall manager")
_syscall_manager = _DummySyscallManager() _syscall_manager = _DummySyscallManager()
def set_syscall_manager(syscall_manager): def set_syscall_manager(syscall_manager):
global _syscall_manager global _syscall_manager
_syscall_manager = syscall_manager _syscall_manager = syscall_manager
@ -123,6 +136,7 @@ def set_syscall_manager(syscall_manager):
kernel_globals = "sequential", "parallel", "delay", "now", "at", "syscall" kernel_globals = "sequential", "parallel", "delay", "now", "at", "syscall"
class _Sequential: class _Sequential:
def __enter__(self): def __enter__(self):
_time_manager.enter_sequential() _time_manager.enter_sequential()
@ -131,6 +145,7 @@ class _Sequential:
_time_manager.exit() _time_manager.exit()
sequential = _Sequential() sequential = _Sequential()
class _Parallel: class _Parallel:
def __enter__(self): def __enter__(self):
_time_manager.enter_parallel() _time_manager.enter_parallel()
@ -139,14 +154,18 @@ class _Parallel:
_time_manager.exit() _time_manager.exit()
parallel = _Parallel() parallel = _Parallel()
def delay(duration): def delay(duration):
_time_manager.take_time(duration) _time_manager.take_time(duration)
def now(): def now():
return _time_manager.get_time() return _time_manager.get_time()
def at(time): def at(time):
_time_manager.set_time(time) _time_manager.set_time(time)
def syscall(*args): def syscall(*args):
return _syscall_manager.do(*args) return _syscall_manager.do(*args)

View File

@ -1,14 +1,17 @@
from collections import namedtuple from collections import namedtuple
from fractions import Fraction from fractions import Fraction
_prefixes_str = "pnum_kMG" _prefixes_str = "pnum_kMG"
_smallest_prefix = Fraction(1, 10**12) _smallest_prefix = Fraction(1, 10**12)
Unit = namedtuple("Unit", "name") Unit = namedtuple("Unit", "name")
class DimensionError(Exception): class DimensionError(Exception):
pass pass
class Quantity: class Quantity:
def __init__(self, amount, unit): def __init__(self, amount, unit):
self.amount = amount self.amount = amount
@ -32,14 +35,17 @@ class Quantity:
else: else:
return str(r_amount) + " " + self.unit.name return str(r_amount) + " " + self.unit.name
# mul/div
def __mul__(self, other): def __mul__(self, other):
if isinstance(other, Quantity): if isinstance(other, Quantity):
return NotImplemented return NotImplemented
return Quantity(self.amount*other, self.unit) return Quantity(self.amount*other, self.unit)
def __rmul__(self, other): def __rmul__(self, other):
if isinstance(other, Quantity): if isinstance(other, Quantity):
return NotImplemented return NotImplemented
return Quantity(other*self.amount, self.unit) return Quantity(other*self.amount, self.unit)
def __truediv__(self, other): def __truediv__(self, other):
if isinstance(other, Quantity): if isinstance(other, Quantity):
if other.unit == self.unit: if other.unit == self.unit:
@ -48,6 +54,7 @@ class Quantity:
return NotImplemented return NotImplemented
else: else:
return Quantity(self.amount/other, self.unit) return Quantity(self.amount/other, self.unit)
def __floordiv__(self, other): def __floordiv__(self, other):
if isinstance(other, Quantity): if isinstance(other, Quantity):
if other.unit == self.unit: if other.unit == self.unit:
@ -57,55 +64,65 @@ class Quantity:
else: else:
return Quantity(self.amount//other, self.unit) return Quantity(self.amount//other, self.unit)
# unary ops
def __neg__(self): def __neg__(self):
return Quantity(-self.amount, self.unit) return Quantity(-self.amount, self.unit)
def __pos__(self):
return Quantity(self.amount, self.unit)
# add/sub
def __add__(self, other): def __add__(self, other):
if self.unit != other.unit: if self.unit != other.unit:
raise DimensionError raise DimensionError
return Quantity(self.amount + other.amount, self.unit) return Quantity(self.amount + other.amount, self.unit)
def __radd__(self, other): def __radd__(self, other):
if self.unit != other.unit: if self.unit != other.unit:
raise DimensionError raise DimensionError
return Quantity(other.amount + self.amount, self.unit) return Quantity(other.amount + self.amount, self.unit)
def __sub__(self, other): def __sub__(self, other):
if self.unit != other.unit: if self.unit != other.unit:
raise DimensionError raise DimensionError
return Quantity(self.amount - other.amount, self.unit) return Quantity(self.amount - other.amount, self.unit)
def __rsub__(self, other): def __rsub__(self, other):
if self.unit != other.unit: if self.unit != other.unit:
raise DimensionError raise DimensionError
return Quantity(other.amount - self.amount, self.unit) return Quantity(other.amount - self.amount, self.unit)
# comparisons
def __lt__(self, other): def __lt__(self, other):
if self.unit != other.unit: if self.unit != other.unit:
raise DimensionError raise DimensionError
return self.amount < other.amount return self.amount < other.amount
def __le__(self, other): def __le__(self, other):
if self.unit != other.unit: if self.unit != other.unit:
raise DimensionError raise DimensionError
return self.amount <= other.amount return self.amount <= other.amount
def __eq__(self, other): def __eq__(self, other):
if self.unit != other.unit: if self.unit != other.unit:
raise DimensionError raise DimensionError
return self.amount == other.amount return self.amount == other.amount
def __ne__(self, other): def __ne__(self, other):
if self.unit != other.unit: if self.unit != other.unit:
raise DimensionError raise DimensionError
return self.amount != other.amount return self.amount != other.amount
def __gt__(self, other): def __gt__(self, other):
if self.unit != other.unit: if self.unit != other.unit:
raise DimensionError raise DimensionError
return self.amount > other.amount return self.amount > other.amount
def __ge__(self, other): def __ge__(self, other):
if self.unit != other.unit: if self.unit != other.unit:
raise DimensionError raise DimensionError
return self.amount >= other.amount return self.amount >= other.amount
def check_unit(value, unit):
if not isinstance(value, Quantity) or value.unit != unit:
raise DimensionError
return value.amount
def _register_unit(name, prefixes): def _register_unit(name, prefixes):
unit = Unit(name) unit = Unit(name)

View File

@ -4,10 +4,12 @@ from artiq.language.core import AutoContext, delay
from artiq.language import units from artiq.language import units
from artiq.sim import time from artiq.sim import time
class Core: class Core:
def run(self, k_function, k_args, k_kwargs): def run(self, k_function, k_args, k_kwargs):
return k_function(*k_args, **k_kwargs) return k_function(*k_args, **k_kwargs)
class Input(AutoContext): class Input(AutoContext):
parameters = "name" parameters = "name"
implicit_core = False implicit_core = False
@ -26,6 +28,7 @@ class Input(AutoContext):
delay(duration) delay(duration)
return result return result
class WaveOutput(AutoContext): class WaveOutput(AutoContext):
parameters = "name" parameters = "name"
implicit_core = False implicit_core = False
@ -34,6 +37,7 @@ class WaveOutput(AutoContext):
time.manager.event(("pulse", self.name, frequency, duration)) time.manager.event(("pulse", self.name, frequency, duration))
delay(duration) delay(duration)
class VoltageOutput(AutoContext): class VoltageOutput(AutoContext):
parameters = "name" parameters = "name"
implicit_core = False implicit_core = False

View File

@ -3,6 +3,7 @@ from operator import itemgetter
from artiq.language.units import * from artiq.language.units import *
from artiq.language import core as core_language from artiq.language import core as core_language
class SequentialTimeContext: class SequentialTimeContext:
def __init__(self, current_time): def __init__(self, current_time):
self.current_time = current_time self.current_time = current_time
@ -12,6 +13,7 @@ class SequentialTimeContext:
self.current_time += amount self.current_time += amount
self.block_duration += amount self.block_duration += amount
class ParallelTimeContext: class ParallelTimeContext:
def __init__(self, current_time): def __init__(self, current_time):
self.current_time = current_time self.current_time = current_time
@ -21,6 +23,7 @@ class ParallelTimeContext:
if amount > self.block_duration: if amount > self.block_duration:
self.block_duration = amount self.block_duration = amount
class Manager: class Manager:
def __init__(self): def __init__(self):
self.stack = [SequentialTimeContext(0*s)] self.stack = [SequentialTimeContext(0*s)]

View File

@ -1,6 +1,7 @@
from artiq.language.units import * from artiq.language.units import *
from artiq.language.core import * from artiq.language.core import *
class AluminumSpectroscopy(AutoContext): class AluminumSpectroscopy(AutoContext):
parameters = "mains_sync laser_cooling spectroscopy spectroscopy_b state_detection pmt \ parameters = "mains_sync laser_cooling spectroscopy spectroscopy_b state_detection pmt \
spectroscopy_freq photon_limit_low photon_limit_high" spectroscopy_freq photon_limit_low photon_limit_high"
@ -24,12 +25,14 @@ class AluminumSpectroscopy(AutoContext):
with parallel: with parallel:
self.state_detection.pulse(100*MHz, 10*us) self.state_detection.pulse(100*MHz, 10*us)
photon_count = self.pmt.count_gate(10*us) photon_count = self.pmt.count_gate(10*us)
if photon_count < self.photon_limit_low or photon_count > self.photon_limit_high: if (photon_count < self.photon_limit_low
or photon_count > self.photon_limit_high):
break break
if photon_count < self.photon_limit_low: if photon_count < self.photon_limit_low:
state_0_count += 1 state_0_count += 1
return state_0_count return state_0_count
if __name__ == "__main__": if __name__ == "__main__":
from artiq.sim import devices as sd from artiq.sim import devices as sd
from artiq.sim import time from artiq.sim import time

View File

@ -1,8 +1,10 @@
from artiq.language.units import * from artiq.language.units import *
from artiq.language.core import * from artiq.language.core import *
my_range = range my_range = range
class CompilerTest(AutoContext): class CompilerTest(AutoContext):
parameters = "a b A B" parameters = "a b A B"
@ -27,15 +29,20 @@ class CompilerTest(AutoContext):
self.B.pulse(100*MHz, t2) self.B.pulse(100*MHz, t2)
self.print_done() self.print_done()
if __name__ == "__main__": if __name__ == "__main__":
from artiq.devices import corecom_dummy, core, dds_core from artiq.devices import corecom_dummy, core, dds_core
coredev = core.Core(corecom_dummy.CoreCom()) coredev = core.Core(corecom_dummy.CoreCom())
exp = CompilerTest( exp = CompilerTest(
core=coredev, core=coredev,
a=dds_core.DDS(core=coredev, dds_sysclk=1*GHz, reg_channel=0, rtio_channel=0), a=dds_core.DDS(core=coredev, dds_sysclk=1*GHz,
b=dds_core.DDS(core=coredev, dds_sysclk=1*GHz, reg_channel=1, rtio_channel=1), reg_channel=0, rtio_channel=0),
A=dds_core.DDS(core=coredev, dds_sysclk=1*GHz, reg_channel=2, rtio_channel=2), b=dds_core.DDS(core=coredev, dds_sysclk=1*GHz,
B=dds_core.DDS(core=coredev, dds_sysclk=1*GHz, reg_channel=3, rtio_channel=3) reg_channel=1, rtio_channel=1),
A=dds_core.DDS(core=coredev, dds_sysclk=1*GHz,
reg_channel=2, rtio_channel=2),
B=dds_core.DDS(core=coredev, dds_sysclk=1*GHz,
reg_channel=3, rtio_channel=3)
) )
exp.run(3, 100*us) exp.run(3, 100*us)

View File

@ -1,6 +1,7 @@
from artiq.language.core import AutoContext, kernel from artiq.language.core import AutoContext, kernel
from artiq.devices import corecom_serial, core, gpio_core from artiq.devices import corecom_serial, core, gpio_core
class CompilerTest(AutoContext): class CompilerTest(AutoContext):
parameters = "led" parameters = "led"
@ -27,6 +28,7 @@ class CompilerTest(AutoContext):
x += 1 x += 1
self.led.set(0) self.led.set(0)
if __name__ == "__main__": if __name__ == "__main__":
with corecom_serial.CoreCom() as com: with corecom_serial.CoreCom() as com:
coredev = core.Core(com) coredev = core.Core(com)

View File

@ -2,6 +2,7 @@ from artiq.language.units import *
from artiq.language.core import * from artiq.language.core import *
from artiq.devices import corecom_serial, core, dds_core, gpio_core from artiq.devices import corecom_serial, core, dds_core, gpio_core
class DDSTest(AutoContext): class DDSTest(AutoContext):
parameters = "a b c d led" parameters = "a b c d led"
@ -23,15 +24,20 @@ class DDSTest(AutoContext):
i += 1 i += 1
self.led.set(0) self.led.set(0)
if __name__ == "__main__": if __name__ == "__main__":
with corecom_serial.CoreCom() as com: with corecom_serial.CoreCom() as com:
coredev = core.Core(com) coredev = core.Core(com)
exp = DDSTest( exp = DDSTest(
core=coredev, core=coredev,
a=dds_core.DDS(core=coredev, dds_sysclk=1*GHz, reg_channel=0, rtio_channel=0), a=dds_core.DDS(core=coredev, dds_sysclk=1*GHz,
b=dds_core.DDS(core=coredev, dds_sysclk=1*GHz, reg_channel=1, rtio_channel=1), reg_channel=0, rtio_channel=0),
c=dds_core.DDS(core=coredev, dds_sysclk=1*GHz, reg_channel=2, rtio_channel=2), b=dds_core.DDS(core=coredev, dds_sysclk=1*GHz,
d=dds_core.DDS(core=coredev, dds_sysclk=1*GHz, reg_channel=3, rtio_channel=3), reg_channel=1, rtio_channel=1),
c=dds_core.DDS(core=coredev, dds_sysclk=1*GHz,
reg_channel=2, rtio_channel=2),
d=dds_core.DDS(core=coredev, dds_sysclk=1*GHz,
reg_channel=3, rtio_channel=3),
led=gpio_core.GPIOOut(core=coredev, channel=1) led=gpio_core.GPIOOut(core=coredev, channel=1)
) )
exp.run() exp.run()

View File

@ -1,6 +1,7 @@
from artiq.language.units import * from artiq.language.units import *
from artiq.language.core import * from artiq.language.core import *
class SimpleSimulation(AutoContext): class SimpleSimulation(AutoContext):
parameters = "a b c d" parameters = "a b c d"
@ -14,6 +15,7 @@ class SimpleSimulation(AutoContext):
self.c.pulse(300*MHz, 10*us) self.c.pulse(300*MHz, 10*us)
self.d.pulse(400*MHz, 20*us) self.d.pulse(400*MHz, 20*us)
if __name__ == "__main__": if __name__ == "__main__":
from artiq.sim import devices as sd from artiq.sim import devices as sd
from artiq.sim import time from artiq.sim import time

View File

@ -2,6 +2,7 @@ from artiq.language.units import *
from artiq.language.core import * from artiq.language.core import *
from artiq.devices import corecom_serial, core from artiq.devices import corecom_serial, core
class DummyPulse(AutoContext): class DummyPulse(AutoContext):
parameters = "name" parameters = "name"
@ -17,6 +18,7 @@ class DummyPulse(AutoContext):
delay(duration) delay(duration)
self.print_off(int(now())) self.print_off(int(now()))
class TimeTest(AutoContext): class TimeTest(AutoContext):
parameters = "a b c d" parameters = "a b c d"
@ -33,6 +35,7 @@ class TimeTest(AutoContext):
self.d.pulse(400+i, 20*us) self.d.pulse(400+i, 20*us)
i += 1 i += 1
if __name__ == "__main__": if __name__ == "__main__":
with corecom_serial.CoreCom() as com: with corecom_serial.CoreCom() as com:
coredev = core.Core(com) coredev = core.Core(com)

View File

@ -4,6 +4,7 @@ from migen.bus import wishbone
from migen.bus.transactions import * from migen.bus.transactions import *
from migen.sim.generic import run_simulation from migen.sim.generic import run_simulation
class AD9858(Module): class AD9858(Module):
"""Wishbone interface to the AD9858 DDS chip. """Wishbone interface to the AD9858 DDS chip.
@ -37,7 +38,7 @@ class AD9858(Module):
bus = wishbone.Interface() bus = wishbone.Interface()
self.bus = bus self.bus = bus
### # # #
dts = TSTriple(8) dts = TSTriple(8)
self.specials += dts.get_tristate(pads.d) self.specials += dts.get_tristate(pads.d)
@ -157,6 +158,7 @@ class AD9858(Module):
NextState("IDLE") NextState("IDLE")
) )
def _test_gen(): def _test_gen():
# Test external bus writes # Test external bus writes
yield TWrite(4, 2) yield TWrite(4, 2)
@ -173,6 +175,7 @@ def _test_gen():
yield TWrite(65, 0xff) yield TWrite(65, 0xff)
yield yield
class _TestPads: class _TestPads:
def __init__(self): def __init__(self):
self.a = Signal(6) self.a = Signal(6)
@ -184,6 +187,7 @@ class _TestPads:
self.rd_n = Signal() self.rd_n = Signal()
self.rst_n = Signal() self.rst_n = Signal()
class _TB(Module): class _TB(Module):
def __init__(self): def __init__(self):
pads = _TestPads() pads = _TestPads()
@ -191,5 +195,6 @@ class _TB(Module):
self.submodules.initiator = wishbone.Initiator(_test_gen()) self.submodules.initiator = wishbone.Initiator(_test_gen())
self.submodules.interconnect = wishbone.InterconnectPointToPoint(self.initiator.bus, self.dut.bus) self.submodules.interconnect = wishbone.InterconnectPointToPoint(self.initiator.bus, self.dut.bus)
if __name__ == "__main__": if __name__ == "__main__":
run_simulation(_TB(), vcd_name="ad9858.vcd") run_simulation(_TB(), vcd_name="ad9858.vcd")

View File

@ -5,6 +5,7 @@ from migen.genlib.cdc import MultiReg
from artiqlib.rtio.rbus import get_fine_ts_width from artiqlib.rtio.rbus import get_fine_ts_width
class _RTIOBankO(Module): class _RTIOBankO(Module):
def __init__(self, rbus, counter_width, fine_ts_width, fifo_depth, counter_init): def __init__(self, rbus, counter_width, fine_ts_width, fifo_depth, counter_init):
self.sel = Signal(max=len(rbus)) self.sel = Signal(max=len(rbus))
@ -15,7 +16,7 @@ class _RTIOBankO(Module):
self.underflow = Signal() self.underflow = Signal()
self.level = Signal(bits_for(fifo_depth)) self.level = Signal(bits_for(fifo_depth))
### # # #
counter = Signal(counter_width, reset=counter_init) counter = Signal(counter_width, reset=counter_init)
self.sync += [ self.sync += [
@ -53,6 +54,7 @@ class _RTIOBankO(Module):
selfifo = Array(fifos)[self.sel] selfifo = Array(fifos)[self.sel]
self.comb += self.writable.eq(selfifo.writable), self.level.eq(selfifo.level) self.comb += self.writable.eq(selfifo.writable), self.level.eq(selfifo.level)
class _RTIOBankI(Module): class _RTIOBankI(Module):
def __init__(self, rbus, counter_width, fine_ts_width, fifo_depth): def __init__(self, rbus, counter_width, fine_ts_width, fifo_depth):
self.sel = Signal(max=len(rbus)) self.sel = Signal(max=len(rbus))
@ -116,6 +118,7 @@ class _RTIOBankI(Module):
self.overflow.eq(Array(overflows)[self.sel]) self.overflow.eq(Array(overflows)[self.sel])
] ]
class RTIO(Module, AutoCSR): class RTIO(Module, AutoCSR):
def __init__(self, phy, counter_width=32, ofifo_depth=8, ififo_depth=8): def __init__(self, phy, counter_width=32, ofifo_depth=8, ififo_depth=8):
fine_ts_width = get_fine_ts_width(phy.rbus) fine_ts_width = get_fine_ts_width(phy.rbus)

View File

@ -3,12 +3,13 @@ from migen.genlib.cdc import MultiReg
from artiqlib.rtio.rbus import create_rbus from artiqlib.rtio.rbus import create_rbus
class SimplePHY(Module): class SimplePHY(Module):
def __init__(self, pads, output_only_pads=set()): def __init__(self, pads, output_only_pads=set()):
self.rbus = create_rbus(0, pads, output_only_pads) self.rbus = create_rbus(0, pads, output_only_pads)
self.loopback_latency = 3 self.loopback_latency = 3
### # # #
for pad, chif in zip(pads, self.rbus): for pad, chif in zip(pads, self.rbus):
o_pad = Signal() o_pad = Signal()