diff --git a/artiq/compiler/fold_constants.py b/artiq/compiler/fold_constants.py index 71f821917..ddebbbf71 100644 --- a/artiq/compiler/fold_constants.py +++ b/artiq/compiler/fold_constants.py @@ -1,81 +1,86 @@ -import ast, operator +import ast +import operator from artiq.compiler.tools import * from artiq.language.core import int64, round64 + _ast_unops = { - ast.Invert: operator.inv, - ast.Not: operator.not_, - ast.UAdd: operator.pos, - ast.USub: operator.neg + ast.Invert: operator.inv, + ast.Not: operator.not_, + ast.UAdd: operator.pos, + ast.USub: operator.neg } + _ast_binops = { - ast.Add: operator.add, - ast.Sub: operator.sub, - ast.Mult: operator.mul, - ast.Div: operator.truediv, - ast.FloorDiv: operator.floordiv, - ast.Mod: operator.mod, - ast.Pow: operator.pow, - ast.LShift: operator.lshift, - ast.RShift: operator.rshift, - ast.BitOr: operator.or_, - ast.BitXor: operator.xor, - ast.BitAnd: operator.and_ + ast.Add: operator.add, + ast.Sub: operator.sub, + ast.Mult: operator.mul, + ast.Div: operator.truediv, + ast.FloorDiv: operator.floordiv, + ast.Mod: operator.mod, + ast.Pow: operator.pow, + ast.LShift: operator.lshift, + ast.RShift: operator.rshift, + ast.BitOr: operator.or_, + ast.BitXor: operator.xor, + ast.BitAnd: operator.and_ } + class _ConstantFolder(ast.NodeTransformer): - def visit_UnaryOp(self, node): - self.generic_visit(node) - try: - operand = eval_constant(node.operand) - except NotConstant: - return node - try: - op = _ast_unops[type(node.op)] - except KeyError: - return node - try: - result = value_to_ast(op(operand)) - except: - return node - return ast.copy_location(result, node) + def visit_UnaryOp(self, node): + self.generic_visit(node) + try: + operand = eval_constant(node.operand) + except NotConstant: + return node + try: + op = _ast_unops[type(node.op)] + except KeyError: + return node + try: + result = value_to_ast(op(operand)) + except: + return node + return ast.copy_location(result, node) - def visit_BinOp(self, node): - self.generic_visit(node) - try: - left, right = eval_constant(node.left), eval_constant(node.right) - except NotConstant: - return node - try: - op = _ast_binops[type(node.op)] - except KeyError: - return node - try: - result = value_to_ast(op(left, right)) - except: - return node - return ast.copy_location(result, node) + def visit_BinOp(self, node): + self.generic_visit(node) + try: + left, right = eval_constant(node.left), eval_constant(node.right) + except NotConstant: + return node + try: + op = _ast_binops[type(node.op)] + except KeyError: + return node + try: + result = value_to_ast(op(left, right)) + except: + return node + return ast.copy_location(result, node) + + def visit_Call(self, node): + self.generic_visit(node) + fn = node.func.id + constant_ops = { + "int": int, + "int64": int64, + "round": round, + "round64": round64 + } + if fn in constant_ops: + try: + arg = eval_constant(node.args[0]) + except NotConstant: + return node + result = value_to_ast(constant_ops[fn](arg)) + return ast.copy_location(result, node) + else: + return node - def visit_Call(self, node): - self.generic_visit(node) - fn = node.func.id - constant_ops = { - "int": int, - "int64": int64, - "round": round, - "round64": round64 - } - if fn in constant_ops: - try: - arg = eval_constant(node.args[0]) - except NotConstant: - return node - result = value_to_ast(constant_ops[fn](arg)) - return ast.copy_location(result, node) - else: - return node def fold_constants(node): - _ConstantFolder().visit(node) + _ConstantFolder().visit(node) diff --git a/artiq/compiler/inline.py b/artiq/compiler/inline.py index 7d674ac80..3e5af0b70 100644 --- a/artiq/compiler/inline.py +++ b/artiq/compiler/inline.py @@ -1,229 +1,253 @@ from collections import namedtuple, defaultdict from fractions import Fraction -import inspect, textwrap, ast +import inspect +import textwrap +import ast from artiq.compiler.tools import eval_ast, value_to_ast from artiq.language import core as core_language from artiq.language import units + _UserVariable = namedtuple("_UserVariable", "name") + def _is_in_attr_list(obj, attr, al): - if not hasattr(obj, al): - return False - return attr in getattr(obj, al).split() + if not hasattr(obj, al): + return False + return attr in getattr(obj, al).split() + class _ReferenceManager: - def __init__(self): - # (id(obj), funcname, local) -> _UserVariable(name) / ast / constant_object - # local is None for kernel attributes - self.to_inlined = dict() - # inlined_name -> use_count - self.use_count = dict() - self.rpc_map = defaultdict(lambda: len(self.rpc_map)) - self.kernel_attr_init = [] + def __init__(self): + # (id(obj), funcname, local) + # -> _UserVariable(name) / ast / constant_object + # local is None for kernel attributes + self.to_inlined = dict() + # inlined_name -> use_count + self.use_count = dict() + self.rpc_map = defaultdict(lambda: len(self.rpc_map)) + self.kernel_attr_init = [] - # reserved names - for kg in core_language.kernel_globals: - self.use_count[kg] = 1 - for name in "int", "round", "int64", "round64", \ - "range", "Fraction", "Quantity", \ - "s_unit", "Hz_unit", "microcycle_unit": - self.use_count[name] = 1 + # reserved names + for kg in core_language.kernel_globals: + self.use_count[kg] = 1 + for name in ("int", "round", "int64", "round64", + "range", "Fraction", "Quantity", + "s_unit", "Hz_unit", "microcycle_unit"): + self.use_count[name] = 1 - def new_name(self, base_name): - if base_name[-1].isdigit(): - base_name += "_" - if base_name in self.use_count: - r = base_name + str(self.use_count[base_name]) - self.use_count[base_name] += 1 - return r - else: - self.use_count[base_name] = 1 - return base_name + def new_name(self, base_name): + if base_name[-1].isdigit(): + base_name += "_" + if base_name in self.use_count: + r = base_name + str(self.use_count[base_name]) + self.use_count[base_name] += 1 + return r + else: + self.use_count[base_name] = 1 + return base_name - def get(self, obj, funcname, ref): - store = isinstance(ref.ctx, ast.Store) + def get(self, obj, funcname, ref): + store = isinstance(ref.ctx, ast.Store) - if isinstance(ref, ast.Name): - key = (id(obj), funcname, ref.id) - try: - return self.to_inlined[key] - except KeyError: - if store: - ival = _UserVariable(self.new_name(ref.id)) - self.to_inlined[key] = ival - return ival + if isinstance(ref, ast.Name): + key = (id(obj), funcname, ref.id) + try: + return self.to_inlined[key] + except KeyError: + if store: + ival = _UserVariable(self.new_name(ref.id)) + self.to_inlined[key] = ival + return ival - if isinstance(ref, ast.Attribute) and isinstance(ref.value, ast.Name): - try: - value = self.to_inlined[(id(obj), funcname, ref.value.id)] - except KeyError: - pass - else: - if _is_in_attr_list(value, ref.attr, "kernel_attr_ro"): - if store: - raise TypeError("Attempted to assign to read-only kernel attribute") - return getattr(value, ref.attr) - if _is_in_attr_list(value, ref.attr, "kernel_attr"): - key = (id(value), ref.attr, None) - try: - ival = self.to_inlined[key] - assert(isinstance(ival, _UserVariable)) - except KeyError: - iname = self.new_name(ref.attr) - ival = _UserVariable(iname) - self.to_inlined[key] = ival - a = value_to_ast(getattr(value, ref.attr)) - if a is None: - raise NotImplementedError("Cannot represent initial value of kernel attribute") - self.kernel_attr_init.append(ast.Assign( - [ast.Name(iname, ast.Store())], a)) - return ival + if isinstance(ref, ast.Attribute) and isinstance(ref.value, ast.Name): + try: + value = self.to_inlined[(id(obj), funcname, ref.value.id)] + except KeyError: + pass + else: + if _is_in_attr_list(value, ref.attr, "kernel_attr_ro"): + if store: + raise TypeError( + "Attempted to assign to read-only" + " kernel attribute") + return getattr(value, ref.attr) + if _is_in_attr_list(value, ref.attr, "kernel_attr"): + key = (id(value), ref.attr, None) + try: + ival = self.to_inlined[key] + assert(isinstance(ival, _UserVariable)) + except KeyError: + iname = self.new_name(ref.attr) + ival = _UserVariable(iname) + self.to_inlined[key] = ival + a = value_to_ast(getattr(value, ref.attr)) + if a is None: + raise NotImplementedError( + "Cannot represent initial value" + " of kernel attribute") + self.kernel_attr_init.append(ast.Assign( + [ast.Name(iname, ast.Store())], a)) + return ival - if not store: - evd = self.get_constants(obj, funcname) - evd.update(inspect.getmodule(obj).__dict__) - return eval_ast(ref, evd) - else: - raise KeyError + if not store: + evd = self.get_constants(obj, funcname) + evd.update(inspect.getmodule(obj).__dict__) + return eval_ast(ref, evd) + else: + raise KeyError - def set(self, obj, funcname, name, value): - self.to_inlined[(id(obj), funcname, name)] = value + def set(self, obj, funcname, name, value): + self.to_inlined[(id(obj), funcname, name)] = value + + def get_constants(self, r_obj, r_funcname): + return { + local: v for (objid, funcname, local), v + in self.to_inlined.items() + if objid == id(r_obj) + and funcname == r_funcname + and not isinstance(v, (_UserVariable, ast.AST))} - def get_constants(self, r_obj, r_funcname): - return {local: v for (objid, funcname, local), v - in self.to_inlined.items() - if objid == id(r_obj) - and funcname == r_funcname - and not isinstance(v, (_UserVariable, ast.AST))} _embeddable_calls = { - core_language.delay, core_language.at, core_language.now, - core_language.syscall, - range, int, round, core_language.int64, core_language.round64, - Fraction, units.Quantity + core_language.delay, core_language.at, core_language.now, + core_language.syscall, + range, int, round, core_language.int64, core_language.round64, + Fraction, units.Quantity } + class _ReferenceReplacer(ast.NodeTransformer): - def __init__(self, core, rm, obj, funcname): - self.core = core - self.rm = rm - self.obj = obj - self.funcname = funcname + def __init__(self, core, rm, obj, funcname): + self.core = core + self.rm = rm + self.obj = obj + self.funcname = funcname - def visit_ref(self, node): - store = isinstance(node.ctx, ast.Store) - ival = self.rm.get(self.obj, self.funcname, node) - if isinstance(ival, _UserVariable): - newnode = ast.Name(ival.name, node.ctx) - elif isinstance(ival, ast.AST): - assert(not store) - newnode = ival - else: - if store: - raise NotImplementedError("Cannot turn object into user variable") - else: - newnode = value_to_ast(ival) - if newnode is None: - raise NotImplementedError("Cannot represent inlined value") - return ast.copy_location(newnode, node) + def visit_ref(self, node): + store = isinstance(node.ctx, ast.Store) + ival = self.rm.get(self.obj, self.funcname, node) + if isinstance(ival, _UserVariable): + newnode = ast.Name(ival.name, node.ctx) + elif isinstance(ival, ast.AST): + assert(not store) + newnode = ival + else: + if store: + raise NotImplementedError( + "Cannot turn object into user variable") + else: + newnode = value_to_ast(ival) + if newnode is None: + raise NotImplementedError( + "Cannot represent inlined value") + return ast.copy_location(newnode, node) - visit_Name = visit_ref - visit_Attribute = visit_ref - visit_Subscript = visit_ref + visit_Name = visit_ref + visit_Attribute = visit_ref + visit_Subscript = visit_ref - def visit_Call(self, node): - func = self.rm.get(self.obj, self.funcname, node.func) - new_args = [self.visit(arg) for arg in node.args] + def visit_Call(self, node): + func = self.rm.get(self.obj, self.funcname, node.func) + new_args = [self.visit(arg) for arg in node.args] - if func in _embeddable_calls: - new_func = ast.Name(func.__name__, ast.Load()) - return ast.copy_location( - ast.Call(func=new_func, args=new_args, - keywords=[], starargs=None, kwargs=None), - node) - elif hasattr(func, "k_function_info") and getattr(func.__self__, func.k_function_info.core_name) is self.core: - args = [func.__self__] + new_args - inlined, _ = inline(self.core, func.k_function_info.k_function, args, dict(), self.rm) - return inlined.body - else: - args = [ast.Str("rpc"), value_to_ast(self.rm.rpc_map[func])] - args += new_args - return ast.copy_location( - ast.Call(func=ast.Name("syscall", ast.Load()), - args=args, keywords=[], starargs=None, kwargs=None), - node) + if func in _embeddable_calls: + new_func = ast.Name(func.__name__, ast.Load()) + return ast.copy_location( + ast.Call(func=new_func, args=new_args, + keywords=[], starargs=None, kwargs=None), + node) + elif (hasattr(func, "k_function_info") + and getattr(func.__self__, func.k_function_info.core_name) + is self.core): + args = [func.__self__] + new_args + inlined, _ = inline(self.core, func.k_function_info.k_function, + args, dict(), self.rm) + return inlined.body + else: + args = [ast.Str("rpc"), value_to_ast(self.rm.rpc_map[func])] + args += new_args + return ast.copy_location( + ast.Call(func=ast.Name("syscall", ast.Load()), + args=args, keywords=[], starargs=None, kwargs=None), + node) - def visit_Expr(self, node): - if isinstance(node.value, ast.Call): - r = self.visit_Call(node.value) - if isinstance(r, list): - return r - else: - node.value = r - return node - else: - self.generic_visit(node) - return node + def visit_Expr(self, node): + if isinstance(node.value, ast.Call): + r = self.visit_Call(node.value) + if isinstance(r, list): + return r + else: + node.value = r + return node + else: + self.generic_visit(node) + return node + + def visit_FunctionDef(self, node): + node.args = ast.arguments(args=[], vararg=None, kwonlyargs=[], + kw_defaults=[], kwarg=None, defaults=[]) + node.decorator_list = [] + self.generic_visit(node) + return node - def visit_FunctionDef(self, node): - node.args = ast.arguments(args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]) - node.decorator_list = [] - self.generic_visit(node) - return node class _ListReadOnlyParams(ast.NodeVisitor): - def visit_FunctionDef(self, node): - if hasattr(self, "read_only_params"): - raise ValueError("More than one function definition") - self.read_only_params = {arg.arg for arg in node.args.args} - self.generic_visit(node) + def visit_FunctionDef(self, node): + if hasattr(self, "read_only_params"): + raise ValueError("More than one function definition") + self.read_only_params = {arg.arg for arg in node.args.args} + self.generic_visit(node) + + def visit_Name(self, node): + if isinstance(node.ctx, ast.Store): + try: + self.read_only_params.remove(node.id) + except KeyError: + pass - def visit_Name(self, node): - if isinstance(node.ctx, ast.Store): - try: - self.read_only_params.remove(node.id) - except KeyError: - pass def _list_read_only_params(funcdef): - lrp = _ListReadOnlyParams() - lrp.visit(funcdef) - return lrp.read_only_params + lrp = _ListReadOnlyParams() + lrp.visit(funcdef) + return lrp.read_only_params + def _initialize_function_params(funcdef, k_args, k_kwargs, rm): - obj = k_args[0] - funcname = funcdef.name - param_init = [] - rop = _list_read_only_params(funcdef) - for arg_ast, arg_value in zip(funcdef.args.args, k_args): - arg_name = arg_ast.arg - if arg_name in rop: - rm.set(obj, funcname, arg_name, arg_value) - else: - target = rm.get(obj, funcname, ast.Name(arg_name, ast.Store())) - value = value_to_ast(arg_value) - param_init.append(ast.Assign(targets=[target], value=value)) - return param_init + obj = k_args[0] + funcname = funcdef.name + param_init = [] + rop = _list_read_only_params(funcdef) + for arg_ast, arg_value in zip(funcdef.args.args, k_args): + arg_name = arg_ast.arg + if arg_name in rop: + rm.set(obj, funcname, arg_name, arg_value) + else: + target = rm.get(obj, funcname, ast.Name(arg_name, ast.Store())) + value = value_to_ast(arg_value) + param_init.append(ast.Assign(targets=[target], value=value)) + return param_init + def inline(core, k_function, k_args, k_kwargs, rm=None): - init_kernel_attr = rm is None - if rm is None: - rm = _ReferenceManager() + init_kernel_attr = rm is None + if rm is None: + rm = _ReferenceManager() - funcdef = ast.parse(textwrap.dedent(inspect.getsource(k_function))).body[0] + funcdef = ast.parse(textwrap.dedent(inspect.getsource(k_function))).body[0] - param_init = _initialize_function_params(funcdef, k_args, k_kwargs, rm) + param_init = _initialize_function_params(funcdef, k_args, k_kwargs, rm) - obj = k_args[0] - funcname = funcdef.name - rr = _ReferenceReplacer(core, rm, obj, funcname) - rr.visit(funcdef) + obj = k_args[0] + funcname = funcdef.name + rr = _ReferenceReplacer(core, rm, obj, funcname) + rr.visit(funcdef) - funcdef.body[0:0] = param_init - if init_kernel_attr: - funcdef.body[0:0] = rm.kernel_attr_init + funcdef.body[0:0] = param_init + if init_kernel_attr: + funcdef.body[0:0] = rm.kernel_attr_init - r_rpc_map = dict((rpc_num, rpc_fun) for rpc_fun, rpc_num in rm.rpc_map.items()) - return funcdef, r_rpc_map + r_rpc_map = dict((rpc_num, rpc_fun) + for rpc_fun, rpc_num in rm.rpc_map.items()) + return funcdef, r_rpc_map diff --git a/artiq/compiler/interleave.py b/artiq/compiler/interleave.py index 07947a2b3..1494e734b 100644 --- a/artiq/compiler/interleave.py +++ b/artiq/compiler/interleave.py @@ -1,105 +1,113 @@ -import ast, types +import ast +import types from artiq.compiler.tools import * + # -1 statement duration could not be pre-determined # 0 statement has no effect on timeline # >0 statement is a static delay that advances the timeline # by the given amount (in microcycles) def _get_duration(stmt): - if isinstance(stmt, (ast.Expr, ast.Assign)): - return _get_duration(stmt.value) - elif isinstance(stmt, ast.If): - if all(_get_duration(s) == 0 for s in stmt.body) and all(_get_duration(s) == 0 for s in stmt.orelse): - return 0 - else: - return -1 - elif isinstance(stmt, ast.Call) and isinstance(stmt.func, ast.Name): - name = stmt.func.id - if name == "delay": - try: - da = eval_constant(stmt.args[0]) - except NotConstant: - da = -1 - return da - else: - return 0 - else: - return 0 + if isinstance(stmt, (ast.Expr, ast.Assign)): + return _get_duration(stmt.value) + elif isinstance(stmt, ast.If): + if (all(_get_duration(s) == 0 for s in stmt.body) + and all(_get_duration(s) == 0 for s in stmt.orelse)): + return 0 + else: + return -1 + elif isinstance(stmt, ast.Call) and isinstance(stmt.func, ast.Name): + name = stmt.func.id + if name == "delay": + try: + da = eval_constant(stmt.args[0]) + except NotConstant: + da = -1 + return da + else: + return 0 + else: + return 0 + def _interleave_timelines(timelines): - r = [] + r = [] - current_stmts = [] - for stmts in timelines: - it = iter(stmts) - try: - stmt = next(it) - except StopIteration: - pass - else: - current_stmts.append(types.SimpleNamespace(delay=_get_duration(stmt), stmt=stmt, it=it)) + current_stmts = [] + for stmts in timelines: + it = iter(stmts) + try: + stmt = next(it) + except StopIteration: + pass + else: + current_stmts.append(types.SimpleNamespace( + delay=_get_duration(stmt), stmt=stmt, it=it)) - while current_stmts: - dt = min(stmt.delay for stmt in current_stmts) - if dt < 0: - # contains statement(s) with indeterminate duration - return None - if dt > 0: - # advance timeline by dt - for stmt in current_stmts: - stmt.delay -= dt - if stmt.delay == 0: - ref_stmt = stmt.stmt - delay_stmt = ast.copy_location( - ast.Expr(ast.Call(func=ast.Name("delay", ast.Load()), - args=[value_to_ast(dt)], - keywords=[], starargs=[], kwargs=[])), - ref_stmt) - r.append(delay_stmt) - else: - for stmt in current_stmts: - if stmt.delay == 0: - r.append(stmt.stmt) - # discard executed statements - exhausted_list = [] - for stmt_i, stmt in enumerate(current_stmts): - if stmt.delay == 0: - try: - stmt.stmt = next(stmt.it) - except StopIteration: - exhausted_list.append(stmt_i) - else: - stmt.delay = _get_duration(stmt.stmt) - for offset, i in enumerate(exhausted_list): - current_stmts.pop(i-offset) + while current_stmts: + dt = min(stmt.delay for stmt in current_stmts) + if dt < 0: + # contains statement(s) with indeterminate duration + return None + if dt > 0: + # advance timeline by dt + for stmt in current_stmts: + stmt.delay -= dt + if stmt.delay == 0: + ref_stmt = stmt.stmt + delay_stmt = ast.copy_location( + ast.Expr(ast.Call( + func=ast.Name("delay", ast.Load()), + args=[value_to_ast(dt)], + keywords=[], starargs=[], kwargs=[])), + ref_stmt) + r.append(delay_stmt) + else: + for stmt in current_stmts: + if stmt.delay == 0: + r.append(stmt.stmt) + # discard executed statements + exhausted_list = [] + for stmt_i, stmt in enumerate(current_stmts): + if stmt.delay == 0: + try: + stmt.stmt = next(stmt.it) + except StopIteration: + exhausted_list.append(stmt_i) + else: + stmt.delay = _get_duration(stmt.stmt) + for offset, i in enumerate(exhausted_list): + current_stmts.pop(i-offset) + + return r - return r def _interleave_stmts(stmts): - replacements = [] - for stmt_i, stmt in enumerate(stmts): - if isinstance(stmt, (ast.For, ast.While, ast.If)): - _interleave_stmts(stmt.body) - _interleave_stmts(stmt.orelse) - elif isinstance(stmt, ast.With): - btype = stmt.items[0].context_expr.id - if btype == "sequential": - _interleave_stmts(stmt.body) - replacements.append((stmt_i, stmt.body)) - elif btype == "parallel": - timelines = [[s] for s in stmt.body] - for timeline in timelines: - _interleave_stmts(timeline) - merged = _interleave_timelines(timelines) - if merged is not None: - replacements.append((stmt_i, merged)) - else: - raise ValueError("Unknown block type: " + btype) - offset = 0 - for location, new_stmts in replacements: - stmts[offset+location:offset+location+1] = new_stmts - offset += len(new_stmts) - 1 + replacements = [] + for stmt_i, stmt in enumerate(stmts): + if isinstance(stmt, (ast.For, ast.While, ast.If)): + _interleave_stmts(stmt.body) + _interleave_stmts(stmt.orelse) + elif isinstance(stmt, ast.With): + btype = stmt.items[0].context_expr.id + if btype == "sequential": + _interleave_stmts(stmt.body) + replacements.append((stmt_i, stmt.body)) + elif btype == "parallel": + timelines = [[s] for s in stmt.body] + for timeline in timelines: + _interleave_stmts(timeline) + merged = _interleave_timelines(timelines) + if merged is not None: + replacements.append((stmt_i, merged)) + else: + raise ValueError("Unknown block type: " + btype) + offset = 0 + for location, new_stmts in replacements: + stmts[offset+location:offset+location+1] = new_stmts + offset += len(new_stmts) - 1 + def interleave(funcdef): - _interleave_stmts(funcdef.body) + _interleave_stmts(funcdef.body) diff --git a/artiq/compiler/ir.py b/artiq/compiler/ir.py index 23a7b2cba..6dd997f26 100644 --- a/artiq/compiler/ir.py +++ b/artiq/compiler/ir.py @@ -3,32 +3,34 @@ from llvm import passes as lp from artiq.compiler import ir_infer_types, ir_ast_body, ir_values -def compile_function(module, env, funcdef): - function_type = lc.Type.function(lc.Type.void(), []) - function = module.add_function(function_type, funcdef.name) - bb = function.append_basic_block("entry") - builder = lc.Builder.new(bb) - ns = ir_infer_types.infer_types(env, funcdef) - for k, v in ns.items(): - v.alloca(builder, k) - visitor = ir_ast_body.Visitor(env, ns, builder) - visitor.visit_statements(funcdef.body) - builder.ret_void() +def compile_function(module, env, funcdef): + function_type = lc.Type.function(lc.Type.void(), []) + function = module.add_function(function_type, funcdef.name) + bb = function.append_basic_block("entry") + builder = lc.Builder.new(bb) + + ns = ir_infer_types.infer_types(env, funcdef) + for k, v in ns.items(): + v.alloca(builder, k) + visitor = ir_ast_body.Visitor(env, ns, builder) + visitor.visit_statements(funcdef.body) + builder.ret_void() + def get_runtime_binary(env, funcdef): - module = lc.Module.new("main") - env.init_module(module) - ir_values.init_module(module) + module = lc.Module.new("main") + env.init_module(module) + ir_values.init_module(module) - compile_function(module, env, funcdef) + compile_function(module, env, funcdef) - pass_manager = lp.PassManager.new() - pass_manager.add(lp.PASS_MEM2REG) - pass_manager.add(lp.PASS_INSTCOMBINE) - pass_manager.add(lp.PASS_REASSOCIATE) - pass_manager.add(lp.PASS_GVN) - pass_manager.add(lp.PASS_SIMPLIFYCFG) - pass_manager.run(module) + pass_manager = lp.PassManager.new() + pass_manager.add(lp.PASS_MEM2REG) + pass_manager.add(lp.PASS_INSTCOMBINE) + pass_manager.add(lp.PASS_REASSOCIATE) + pass_manager.add(lp.PASS_GVN) + pass_manager.add(lp.PASS_SIMPLIFYCFG) + pass_manager.run(module) - return env.emit_object() + return env.emit_object() diff --git a/artiq/compiler/ir_ast_body.py b/artiq/compiler/ir_ast_body.py index 3437c1b56..b4dc901f7 100644 --- a/artiq/compiler/ir_ast_body.py +++ b/artiq/compiler/ir_ast_body.py @@ -2,187 +2,203 @@ import ast from artiq.compiler import ir_values + class Visitor: - def __init__(self, env, ns, builder=None): - self.env = env - self.ns = ns - self.builder = builder + def __init__(self, env, ns, builder=None): + self.env = env + self.ns = ns + self.builder = builder - # builder can be None for visit_expression - def visit_expression(self, node): - method = "_visit_expr_" + node.__class__.__name__ - try: - visitor = getattr(self, method) - except AttributeError: - raise NotImplementedError("Unsupported node '{}' in expression".format(node.__class__.__name__)) - return visitor(node) + # builder can be None for visit_expression + def visit_expression(self, node): + method = "_visit_expr_" + node.__class__.__name__ + try: + visitor = getattr(self, method) + except AttributeError: + raise NotImplementedError("Unsupported node '{}' in expression" + .format(node.__class__.__name__)) + return visitor(node) - def _visit_expr_Name(self, node): - try: - r = self.ns[node.id] - except KeyError: - raise NameError("Name '{}' is not defined".format(node.id)) - return r + def _visit_expr_Name(self, node): + try: + r = self.ns[node.id] + except KeyError: + raise NameError("Name '{}' is not defined".format(node.id)) + return r - def _visit_expr_NameConstant(self, node): - v = node.value - if v is None: - r = ir_values.VNone() - elif isinstance(v, bool): - r = ir_values.VBool() - else: - raise NotImplementedError - if self.builder is not None: - r.set_const_value(self.builder, v) - return r + def _visit_expr_NameConstant(self, node): + v = node.value + if v is None: + r = ir_values.VNone() + elif isinstance(v, bool): + r = ir_values.VBool() + else: + raise NotImplementedError + if self.builder is not None: + r.set_const_value(self.builder, v) + return r - def _visit_expr_Num(self, node): - n = node.n - if isinstance(n, int): - if abs(n) < 2**31: - r = ir_values.VInt() - else: - r = ir_values.VInt(64) - else: - raise NotImplementedError - if self.builder is not None: - r.set_const_value(self.builder, n) - return r + def _visit_expr_Num(self, node): + n = node.n + if isinstance(n, int): + if abs(n) < 2**31: + r = ir_values.VInt() + else: + r = ir_values.VInt(64) + else: + raise NotImplementedError + if self.builder is not None: + r.set_const_value(self.builder, n) + return r - def _visit_expr_UnaryOp(self, node): - ast_unops = { - ast.Invert: ir_values.operators.inv, - ast.Not: ir_values.operators.not_, - ast.UAdd: ir_values.operators.pos, - ast.USub: ir_values.operators.neg - } - return ast_unops[type(node.op)](self.visit_expression(node.operand), self.builder) + def _visit_expr_UnaryOp(self, node): + ast_unops = { + ast.Invert: ir_values.operators.inv, + ast.Not: ir_values.operators.not_, + ast.UAdd: ir_values.operators.pos, + ast.USub: ir_values.operators.neg + } + return ast_unops[type(node.op)](self.visit_expression(node.operand), + self.builder) - def _visit_expr_BinOp(self, node): - ast_binops = { - ast.Add: ir_values.operators.add, - ast.Sub: ir_values.operators.sub, - ast.Mult: ir_values.operators.mul, - ast.Div: ir_values.operators.truediv, - ast.FloorDiv: ir_values.operators.floordiv, - ast.Mod: ir_values.operators.mod, - ast.Pow: ir_values.operators.pow, - ast.LShift: ir_values.operators.lshift, - ast.RShift: ir_values.operators.rshift, - ast.BitOr: ir_values.operators.or_, - ast.BitXor: ir_values.operators.xor, - ast.BitAnd: ir_values.operators.and_ - } - return ast_binops[type(node.op)](self.visit_expression(node.left), self.visit_expression(node.right), self.builder) + def _visit_expr_BinOp(self, node): + ast_binops = { + ast.Add: ir_values.operators.add, + ast.Sub: ir_values.operators.sub, + ast.Mult: ir_values.operators.mul, + ast.Div: ir_values.operators.truediv, + ast.FloorDiv: ir_values.operators.floordiv, + ast.Mod: ir_values.operators.mod, + ast.Pow: ir_values.operators.pow, + ast.LShift: ir_values.operators.lshift, + ast.RShift: ir_values.operators.rshift, + ast.BitOr: ir_values.operators.or_, + ast.BitXor: ir_values.operators.xor, + ast.BitAnd: ir_values.operators.and_ + } + return ast_binops[type(node.op)](self.visit_expression(node.left), + self.visit_expression(node.right), + self.builder) - def _visit_expr_Compare(self, node): - ast_cmps = { - ast.Eq: ir_values.operators.eq, - ast.NotEq: ir_values.operators.ne, - ast.Lt: ir_values.operators.lt, - ast.LtE: ir_values.operators.le, - ast.Gt: ir_values.operators.gt, - ast.GtE: ir_values.operators.ge - } - comparisons = [] - old_comparator = self.visit_expression(node.left) - for op, comparator_a in zip(node.ops, node.comparators): - comparator = self.visit_expression(comparator_a) - comparison = ast_cmps[type(op)](old_comparator, comparator, self.builder) - comparisons.append(comparison) - old_comparator = comparator - r = comparisons[0] - for comparison in comparisons[1:]: - r = ir_values.operators.and_(r, comparison) - return r + def _visit_expr_Compare(self, node): + ast_cmps = { + ast.Eq: ir_values.operators.eq, + ast.NotEq: ir_values.operators.ne, + ast.Lt: ir_values.operators.lt, + ast.LtE: ir_values.operators.le, + ast.Gt: ir_values.operators.gt, + ast.GtE: ir_values.operators.ge + } + comparisons = [] + old_comparator = self.visit_expression(node.left) + for op, comparator_a in zip(node.ops, node.comparators): + comparator = self.visit_expression(comparator_a) + comparison = ast_cmps[type(op)](old_comparator, comparator, + self.builder) + comparisons.append(comparison) + old_comparator = comparator + r = comparisons[0] + for comparison in comparisons[1:]: + r = ir_values.operators.and_(r, comparison) + return r - def _visit_expr_Call(self, node): - ast_unfuns = { - "bool": ir_values.operators.bool, - "int": ir_values.operators.int, - "int64": ir_values.operators.int64, - "round": ir_values.operators.round, - "round64": ir_values.operators.round64, - } - fn = node.func.id - if fn in ast_unfuns: - return ast_unfuns[fn](self.visit_expression(node.args[0]), self.builder) - elif fn == "Fraction": - r = ir_values.VFraction() - if self.builder is not None: - numerator = self.visit_expression(node.args[0]) - denominator = self.visit_expression(node.args[1]) - r.set_value_nd(self.builder, numerator, denominator) - return r - elif fn == "syscall": - return self.env.syscall(node.args[0].s, - [self.visit_expression(expr) for expr in node.args[1:]], - self.builder) - else: - raise NameError("Function '{}' is not defined".format(fn)) + def _visit_expr_Call(self, node): + ast_unfuns = { + "bool": ir_values.operators.bool, + "int": ir_values.operators.int, + "int64": ir_values.operators.int64, + "round": ir_values.operators.round, + "round64": ir_values.operators.round64, + } + fn = node.func.id + if fn in ast_unfuns: + return ast_unfuns[fn](self.visit_expression(node.args[0]), + self.builder) + elif fn == "Fraction": + r = ir_values.VFraction() + if self.builder is not None: + numerator = self.visit_expression(node.args[0]) + denominator = self.visit_expression(node.args[1]) + r.set_value_nd(self.builder, numerator, denominator) + return r + elif fn == "syscall": + return self.env.syscall( + node.args[0].s, + [self.visit_expression(expr) for expr in node.args[1:]], + self.builder) + else: + raise NameError("Function '{}' is not defined".format(fn)) - def visit_statements(self, stmts): - for node in stmts: - method = "_visit_stmt_" + node.__class__.__name__ - try: - visitor = getattr(self, method) - except AttributeError: - raise NotImplementedError("Unsupported node '{}' in statement".format(node.__class__.__name__)) - visitor(node) + def visit_statements(self, stmts): + for node in stmts: + method = "_visit_stmt_" + node.__class__.__name__ + try: + visitor = getattr(self, method) + except AttributeError: + raise NotImplementedError("Unsupported node '{}' in statement" + .format(node.__class__.__name__)) + visitor(node) - def _visit_stmt_Assign(self, node): - val = self.visit_expression(node.value) - for target in node.targets: - if isinstance(target, ast.Name): - self.ns[target.id].set_value(self.builder, val) - else: - raise NotImplementedError + def _visit_stmt_Assign(self, node): + val = self.visit_expression(node.value) + for target in node.targets: + if isinstance(target, ast.Name): + self.ns[target.id].set_value(self.builder, val) + else: + raise NotImplementedError - def _visit_stmt_AugAssign(self, node): - val = self.visit_expression(ast.BinOp(op=node.op, left=node.target, right=node.value)) - if isinstance(node.target, ast.Name): - self.ns[node.target.id].set_value(self.builder, val) - else: - raise NotImplementedError + def _visit_stmt_AugAssign(self, node): + val = self.visit_expression(ast.BinOp(op=node.op, left=node.target, + right=node.value)) + if isinstance(node.target, ast.Name): + self.ns[node.target.id].set_value(self.builder, val) + else: + raise NotImplementedError - def _visit_stmt_Expr(self, node): - self.visit_expression(node.value) + def _visit_stmt_Expr(self, node): + self.visit_expression(node.value) - def _visit_stmt_If(self, node): - function = self.builder.basic_block.function - then_block = function.append_basic_block("i_then") - else_block = function.append_basic_block("i_else") - merge_block = function.append_basic_block("i_merge") + def _visit_stmt_If(self, node): + function = self.builder.basic_block.function + then_block = function.append_basic_block("i_then") + else_block = function.append_basic_block("i_else") + merge_block = function.append_basic_block("i_merge") - condition = ir_values.operators.bool(self.visit_expression(node.test), self.builder) - self.builder.cbranch(condition.get_ssa_value(self.builder), then_block, else_block) + condition = ir_values.operators.bool(self.visit_expression(node.test), + self.builder) + self.builder.cbranch(condition.get_ssa_value(self.builder), + then_block, else_block) - self.builder.position_at_end(then_block) - self.visit_statements(node.body) - self.builder.branch(merge_block) + self.builder.position_at_end(then_block) + self.visit_statements(node.body) + self.builder.branch(merge_block) - self.builder.position_at_end(else_block) - self.visit_statements(node.orelse) - self.builder.branch(merge_block) + self.builder.position_at_end(else_block) + self.visit_statements(node.orelse) + self.builder.branch(merge_block) - self.builder.position_at_end(merge_block) + self.builder.position_at_end(merge_block) - def _visit_stmt_While(self, node): - function = self.builder.basic_block.function - body_block = function.append_basic_block("w_body") - else_block = function.append_basic_block("w_else") - merge_block = function.append_basic_block("w_merge") + def _visit_stmt_While(self, node): + function = self.builder.basic_block.function + body_block = function.append_basic_block("w_body") + else_block = function.append_basic_block("w_else") + merge_block = function.append_basic_block("w_merge") - condition = ir_values.operators.bool(self.visit_expression(node.test), self.builder) - self.builder.cbranch(condition.get_ssa_value(self.builder), body_block, else_block) + condition = ir_values.operators.bool( + self.visit_expression(node.test), self.builder) + self.builder.cbranch( + condition.get_ssa_value(self.builder), body_block, else_block) - self.builder.position_at_end(body_block) - self.visit_statements(node.body) - condition = ir_values.operators.bool(self.visit_expression(node.test), self.builder) - self.builder.cbranch(condition.get_ssa_value(self.builder), body_block, merge_block) + self.builder.position_at_end(body_block) + self.visit_statements(node.body) + condition = ir_values.operators.bool( + self.visit_expression(node.test), self.builder) + self.builder.cbranch( + condition.get_ssa_value(self.builder), body_block, merge_block) - self.builder.position_at_end(else_block) - self.visit_statements(node.orelse) - self.builder.branch(merge_block) + self.builder.position_at_end(else_block) + self.visit_statements(node.orelse) + self.builder.branch(merge_block) - self.builder.position_at_end(merge_block) + self.builder.position_at_end(merge_block) diff --git a/artiq/compiler/ir_infer_types.py b/artiq/compiler/ir_infer_types.py index 6f40e4e8e..7d385a106 100644 --- a/artiq/compiler/ir_infer_types.py +++ b/artiq/compiler/ir_infer_types.py @@ -4,46 +4,49 @@ from copy import deepcopy from artiq.compiler.ir_ast_body import Visitor + class _TypeScanner(ast.NodeVisitor): - def __init__(self, env, ns): - self.exprv = Visitor(env, ns) + def __init__(self, env, ns): + self.exprv = Visitor(env, ns) - def visit_Assign(self, node): - val = self.exprv.visit_expression(node.value) - ns = self.exprv.ns - for target in node.targets: - if isinstance(target, ast.Name): - if target.id in ns: - ns[target.id].merge(val) - else: - ns[target.id] = val - else: - raise NotImplementedError + def visit_Assign(self, node): + val = self.exprv.visit_expression(node.value) + ns = self.exprv.ns + for target in node.targets: + if isinstance(target, ast.Name): + if target.id in ns: + ns[target.id].merge(val) + else: + ns[target.id] = val + else: + raise NotImplementedError + + def visit_AugAssign(self, node): + val = self.exprv.visit_expression(ast.BinOp( + op=node.op, left=node.target, right=node.value)) + ns = self.exprv.ns + target = node.target + if isinstance(target, ast.Name): + if target.id in ns: + ns[target.id].merge(val) + else: + ns[target.id] = val + else: + raise NotImplementedError - def visit_AugAssign(self, node): - val = self.exprv.visit_expression(ast.BinOp(op=node.op, left=node.target, right=node.value)) - ns = self.exprv.ns - target = node.target - if isinstance(target, ast.Name): - if target.id in ns: - ns[target.id].merge(val) - else: - ns[target.id] = val - else: - raise NotImplementedError def infer_types(env, node): - ns = dict() - while True: - prev_ns = deepcopy(ns) - ts = _TypeScanner(env, ns) - ts.visit(node) - if prev_ns and all(v.same_type(prev_ns[k]) for k, v in ns.items()): - # no more promotions - completed - return ns + ns = dict() + while True: + prev_ns = deepcopy(ns) + ts = _TypeScanner(env, ns) + ts.visit(node) + if prev_ns and all(v.same_type(prev_ns[k]) for k, v in ns.items()): + # no more promotions - completed + return ns if __name__ == "__main__": - testcode = """ + testcode = """ a = 2 # promoted later to int64 b = a + 1 # initially int32, becomes int64 after a is promoted c = b//2 # initially int32, becomes int64 after b is promoted @@ -53,6 +56,6 @@ a += x # promotes a to int64 foo = True bar = None """ - ns = infer_types(None, ast.parse(testcode)) - for k, v in sorted(ns.items(), key=itemgetter(0)): - print("{:10}--> {}".format(k, str(v))) + ns = infer_types(None, ast.parse(testcode)) + for k, v in sorted(ns.items(), key=itemgetter(0)): + print("{:10}--> {}".format(k, str(v))) diff --git a/artiq/compiler/ir_values.py b/artiq/compiler/ir_values.py index ac34d3c3f..1b3a8e6fb 100644 --- a/artiq/compiler/ir_values.py +++ b/artiq/compiler/ir_values.py @@ -2,404 +2,450 @@ from types import SimpleNamespace from llvm import core as lc + class _Value: - def __init__(self): - self._llvm_value = None + def __init__(self): + self._llvm_value = None - def get_ssa_value(self, builder): - if isinstance(self._llvm_value, lc.AllocaInstruction): - return builder.load(self._llvm_value) - else: - return self._llvm_value + def get_ssa_value(self, builder): + if isinstance(self._llvm_value, lc.AllocaInstruction): + return builder.load(self._llvm_value) + else: + return self._llvm_value - def set_ssa_value(self, builder, value): - if self._llvm_value is None: - self._llvm_value = value - elif isinstance(self._llvm_value, lc.AllocaInstruction): - builder.store(value, self._llvm_value) - else: - raise RuntimeError("Attempted to set LLVM SSA value multiple times") + def set_ssa_value(self, builder, value): + if self._llvm_value is None: + self._llvm_value = value + elif isinstance(self._llvm_value, lc.AllocaInstruction): + builder.store(value, self._llvm_value) + else: + raise RuntimeError( + "Attempted to set LLVM SSA value multiple times") - def alloca(self, builder, name): - if self._llvm_value is not None: - raise RuntimeError("Attempted to alloca existing LLVM value") - self._llvm_value = builder.alloca(self.get_llvm_type(), name=name) + def alloca(self, builder, name): + if self._llvm_value is not None: + raise RuntimeError("Attempted to alloca existing LLVM value") + self._llvm_value = builder.alloca(self.get_llvm_type(), name=name) - def o_int(self, builder): - return self.o_intx(32, builder) + def o_int(self, builder): + return self.o_intx(32, builder) - def o_int64(self, builder): - return self.o_intx(64, builder) + def o_int64(self, builder): + return self.o_intx(64, builder) - def o_round(self, builder): - return self.o_roundx(32, builder) + def o_round(self, builder): + return self.o_roundx(32, builder) + + def o_round64(self, builder): + return self.o_roundx(64, builder) - def o_round64(self, builder): - return self.o_roundx(64, builder) # None type class VNone(_Value): - def __repr__(self): - return "" + def __repr__(self): + return "" - def get_llvm_type(self): - return lc.Type.void() + def get_llvm_type(self): + return lc.Type.void() - def same_type(self, other): - return isinstance(other, VNone) + def same_type(self, other): + return isinstance(other, VNone) - def merge(self, other): - if not isinstance(other, VNone): - raise TypeError + def merge(self, other): + if not isinstance(other, VNone): + raise TypeError - def alloca(self, builder, name): - pass + def alloca(self, builder, name): + pass + + def o_bool(self, builder): + r = VBool() + if builder is not None: + r.set_const_value(builder, False) + return r - def o_bool(self, builder): - r = VBool() - if builder is not None: - r.set_const_value(builder, False) - return r # Integer type class VInt(_Value): - def __init__(self, nbits=32): - _Value.__init__(self) - self.nbits = nbits + def __init__(self, nbits=32): + _Value.__init__(self) + self.nbits = nbits - def get_llvm_type(self): - return lc.Type.int(self.nbits) + def get_llvm_type(self): + return lc.Type.int(self.nbits) - def __repr__(self): - return "".format(self.nbits) + def __repr__(self): + return "".format(self.nbits) - def same_type(self, other): - return isinstance(other, VInt) and other.nbits == self.nbits + def same_type(self, other): + return isinstance(other, VInt) and other.nbits == self.nbits - def merge(self, other): - if isinstance(other, VInt) and not isinstance(other, VBool): - if other.nbits > self.nbits: - self.nbits = other.nbits - else: - raise TypeError + def merge(self, other): + if isinstance(other, VInt) and not isinstance(other, VBool): + if other.nbits > self.nbits: + self.nbits = other.nbits + else: + raise TypeError - def set_value(self, builder, n): - self.set_ssa_value(builder, n.o_intx(self.nbits, builder).get_ssa_value(builder)) + def set_value(self, builder, n): + self.set_ssa_value( + builder, n.o_intx(self.nbits, builder).get_ssa_value(builder)) - def set_const_value(self, builder, n): - self.set_ssa_value(builder, lc.Constant.int(self.get_llvm_type(), n)) + def set_const_value(self, builder, n): + self.set_ssa_value(builder, lc.Constant.int(self.get_llvm_type(), n)) - def o_bool(self, builder): - r = VBool() - if builder is not None: - r.set_ssa_value(builder, builder.icmp(lc.ICMP_NE, - self.get_ssa_value(builder), lc.Constant.int(self.get_llvm_type(), 0))) - return r + def o_bool(self, builder): + r = VBool() + if builder is not None: + r.set_ssa_value( + builder, builder.icmp( + lc.ICMP_NE, + self.get_ssa_value(builder), + lc.Constant.int(self.get_llvm_type(), 0))) + return r + + def o_intx(self, target_bits, builder): + r = VInt(target_bits) + if builder is not None: + if self.nbits == target_bits: + r.set_ssa_value( + builder, self.get_ssa_value(builder)) + if self.nbits > target_bits: + r.set_ssa_value( + builder, builder.trunc(self.get_ssa_value(builder), + r.get_llvm_type())) + if self.nbits < target_bits: + r.set_ssa_value( + builder, builder.sext(self.get_ssa_value(builder), + r.get_llvm_type())) + return r + o_roundx = o_intx - def o_intx(self, target_bits, builder): - r = VInt(target_bits) - if builder is not None: - if self.nbits == target_bits: - r.set_ssa_value(builder, self.get_ssa_value(builder)) - if self.nbits > target_bits: - r.set_ssa_value(builder, builder.trunc(self.get_ssa_value(builder), r.get_llvm_type())) - if self.nbits < target_bits: - r.set_ssa_value(builder, builder.sext(self.get_ssa_value(builder), r.get_llvm_type())) - return r - o_roundx = o_intx def _make_vint_binop_method(builder_name): - def binop_method(self, other, builder): - if isinstance(other, VInt): - target_bits = max(self.nbits, other.nbits) - r = VInt(target_bits) - if builder is not None: - left = self.o_intx(target_bits, builder) - right = other.o_intx(target_bits, builder) - bf = getattr(builder, builder_name) - r.set_ssa_value(builder, - bf(left.get_ssa_value(builder), right.get_ssa_value(builder))) - return r - else: - return NotImplemented - return binop_method + def binop_method(self, other, builder): + if isinstance(other, VInt): + target_bits = max(self.nbits, other.nbits) + r = VInt(target_bits) + if builder is not None: + left = self.o_intx(target_bits, builder) + right = other.o_intx(target_bits, builder) + bf = getattr(builder, builder_name) + r.set_ssa_value( + builder, bf(left.get_ssa_value(builder), + right.get_ssa_value(builder))) + return r + else: + return NotImplemented + return binop_method + +for _method_name, _builder_name in (("o_add", "add"), + ("o_sub", "sub"), + ("o_mul", "mul"), + ("o_floordiv", "sdiv"), + ("o_mod", "srem"), + ("o_and", "and_"), + ("o_xor", "xor"), + ("o_or", "or_")): + setattr(VInt, _method_name, _make_vint_binop_method(_builder_name)) -for _method_name, _builder_name in ( - ("o_add", "add"), - ("o_sub", "sub"), - ("o_mul", "mul"), - ("o_floordiv", "sdiv"), - ("o_mod", "srem"), - ("o_and", "and_"), - ("o_xor", "xor"), - ("o_or", "or_")): - setattr(VInt, _method_name, _make_vint_binop_method(_builder_name)) def _make_vint_cmp_method(icmp_val): - def cmp_method(self, other, builder): - if isinstance(other, VInt): - r = VBool() - if builder is not None: - target_bits = max(self.nbits, other.nbits) - left = self.o_intx(target_bits, builder) - right = other.o_intx(target_bits, builder) - r.set_ssa_value(builder, - builder.icmp(icmp_val, left.get_ssa_value(builder), right.get_ssa_value(builder))) - return r - else: - return NotImplemented - return cmp_method + def cmp_method(self, other, builder): + if isinstance(other, VInt): + r = VBool() + if builder is not None: + target_bits = max(self.nbits, other.nbits) + left = self.o_intx(target_bits, builder) + right = other.o_intx(target_bits, builder) + r.set_ssa_value( + builder, + builder.icmp( + icmp_val, left.get_ssa_value(builder), + right.get_ssa_value(builder))) + return r + else: + return NotImplemented + return cmp_method + +for _method_name, _icmp_val in (("o_eq", lc.ICMP_EQ), + ("o_ne", lc.ICMP_NE), + ("o_lt", lc.ICMP_SLT), + ("o_le", lc.ICMP_SLE), + ("o_gt", lc.ICMP_SGT), + ("o_ge", lc.ICMP_SGE)): + setattr(VInt, _method_name, _make_vint_cmp_method(_icmp_val)) -for _method_name, _icmp_val in ( - ("o_eq", lc.ICMP_EQ), - ("o_ne", lc.ICMP_NE), - ("o_lt", lc.ICMP_SLT), - ("o_le", lc.ICMP_SLE), - ("o_gt", lc.ICMP_SGT), - ("o_ge", lc.ICMP_SGE)): - setattr(VInt, _method_name, _make_vint_cmp_method(_icmp_val)) # Boolean type class VBool(VInt): - def __init__(self): - VInt.__init__(self, 1) + def __init__(self): + VInt.__init__(self, 1) - def __repr__(self): - return "" + def __repr__(self): + return "" - def same_type(self, other): - return isinstance(other, VBool) + def same_type(self, other): + return isinstance(other, VBool) - def merge(self, other): - if not isinstance(other, VBool): - raise TypeError + def merge(self, other): + if not isinstance(other, VBool): + raise TypeError - def set_const_value(self, builder, b): - VInt.set_const_value(self, builder, int(b)) + def set_const_value(self, builder, b): + VInt.set_const_value(self, builder, int(b)) + + def o_bool(self, builder): + r = VBool() + if builder is not None: + r.set_ssa_value(builder, self.get_ssa_value(builder)) + return r - def o_bool(self, builder): - r = VBool() - if builder is not None: - r.set_ssa_value(builder, self.get_ssa_value(builder)) - return r # Fraction type def _gcd64(builder, a, b): - gcd_f = builder.module.get_function_named("__gcd64") - return builder.call(gcd_f, [a, b]) + gcd_f = builder.module.get_function_named("__gcd64") + return builder.call(gcd_f, [a, b]) + def _frac_normalize(builder, numerator, denominator): - gcd = _gcd64(numerator, denominator) - numerator = builder.sdiv(numerator, gcd) - denominator = builder.sdiv(denominator, gcd) - return numerator, denominator + gcd = _gcd64(numerator, denominator) + numerator = builder.sdiv(numerator, gcd) + denominator = builder.sdiv(denominator, gcd) + return numerator, denominator + def _frac_make_ssa(builder, numerator, denominator): - value = lc.Constant.undef(lc.Type.vector(lc.Type.int(64), 2)) - value = builder.insert_element(value, numerator, lc.Constant.int(lc.Type.int(), 0)) - value = builder.insert_element(value, denominator, lc.Constant.int(lc.Type.int(), 1)) - return value + value = lc.Constant.undef(lc.Type.vector(lc.Type.int(64), 2)) + value = builder.insert_element( + value, numerator, lc.Constant.int(lc.Type.int(), 0)) + value = builder.insert_element( + value, denominator, lc.Constant.int(lc.Type.int(), 1)) + return value + class VFraction(_Value): - def get_llvm_type(self): - return lc.Type.vector(lc.Type.int(64), 2) + def get_llvm_type(self): + return lc.Type.vector(lc.Type.int(64), 2) - def __repr__(self): - return "" + def __repr__(self): + return "" - def same_type(self, other): - return isinstance(other, VFraction) + def same_type(self, other): + return isinstance(other, VFraction) - def merge(self, other): - if not isinstance(other, VFraction): - raise TypeError + def merge(self, other): + if not isinstance(other, VFraction): + raise TypeError - def _nd(self, builder, invert=False): - ssa_value = self.get_ssa_value(builder) - numerator = builder.extract_element(ssa_value, lc.Constant.int(lc.Type.int(), 0)) - denominator = builder.extract_element(ssa_value, lc.Constant.int(lc.Type.int(), 1)) - if invert: - return denominator, numerator - else: - return numerator, denominator + def _nd(self, builder, invert=False): + ssa_value = self.get_ssa_value(builder) + numerator = builder.extract_element( + ssa_value, lc.Constant.int(lc.Type.int(), 0)) + denominator = builder.extract_element( + ssa_value, lc.Constant.int(lc.Type.int(), 1)) + if invert: + return denominator, numerator + else: + return numerator, denominator - def set_value_nd(self, builder, numerator, denominator): - numerator = numerator.o_int64(builder).get_ssa_value(builder) - denominator = denominator.o_int64(builder).get_ssa_value(builder) - numerator, denominator = _frac_normalize(builder, numerator, denominator) - self.set_ssa_value(builder, _frac_make_ssa(builder, numerator, denominator)) + def set_value_nd(self, builder, numerator, denominator): + numerator = numerator.o_int64(builder).get_ssa_value(builder) + denominator = denominator.o_int64(builder).get_ssa_value(builder) + numerator, denominator = _frac_normalize( + builder, numerator, denominator) + self.set_ssa_value( + builder, _frac_make_ssa(builder, numerator, denominator)) - def set_value(self, builder, n): - if not isinstance(n, VFraction): - raise TypeError - self.set_ssa_value(builder, n.get_ssa_value(builder)) + def set_value(self, builder, n): + if not isinstance(n, VFraction): + raise TypeError + self.set_ssa_value(builder, n.get_ssa_value(builder)) - def o_bool(self, builder): - r = VBool() - if builder is not None: - zero = lc.Constant.int(lc.Type.int(64), 0) - numerator = builder.extract_element(self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 0)) - r.set_ssa_value(builder, builder.icmp(lc.ICMP_NE, numerator, zero)) - return r + def o_bool(self, builder): + r = VBool() + if builder is not None: + zero = lc.Constant.int(lc.Type.int(64), 0) + numerator = builder.extract_element( + self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), 0)) + r.set_ssa_value(builder, builder.icmp(lc.ICMP_NE, numerator, zero)) + return r - def o_intx(self, target_bits, builder): - if builder is None: - return VInt(target_bits) - else: - r = VInt(64) - numerator, denominator = self._nd(builder) - r.set_ssa_value(builder, builder.sdiv(numerator, denominator)) - return r.o_intx(target_bits, builder) + def o_intx(self, target_bits, builder): + if builder is None: + return VInt(target_bits) + else: + r = VInt(64) + numerator, denominator = self._nd(builder) + r.set_ssa_value(builder, builder.sdiv(numerator, denominator)) + return r.o_intx(target_bits, builder) - def o_roundx(self, target_bits, builder): - if builder is None: - return VInt(target_bits) - else: - r = VInt(64) - numerator, denominator = self._nd(builder) - h_denominator = builder.ashr(denominator, lc.Constant.int(lc.Type.int(), 1)) - r_numerator = builder.add(numerator, h_denominator) - r.set_ssa_value(builder, builder.sdiv(r_numerator, denominator)) - return r.o_intx(target_bits, builder) + def o_roundx(self, target_bits, builder): + if builder is None: + return VInt(target_bits) + else: + r = VInt(64) + numerator, denominator = self._nd(builder) + h_denominator = builder.ashr(denominator, + lc.Constant.int(lc.Type.int(), 1)) + r_numerator = builder.add(numerator, h_denominator) + r.set_ssa_value(builder, builder.sdiv(r_numerator, denominator)) + return r.o_intx(target_bits, builder) - def _o_eq_inv(self, other, builder, ne): - if isinstance(other, VFraction): - r = VBool() - if builder is not None: - ee = [] - for i in range(2): - es = builder.extract_element(self.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), i)) - eo = builder.extract_element(other.get_ssa_value(builder), lc.Constant.int(lc.Type.int(), i)) - ee.append(builder.icmp(lc.ICMP_EQ, es, eo)) - ssa_r = builder.and_(ee[0], ee[1]) - if ne: - ssa_r = builder.xor(ssa_r, lc.Constant.int(lc.Type.int(1), 1)) - r.set_ssa_value(builder, ssa_r) - return r - else: - return NotImplemented + def _o_eq_inv(self, other, builder, ne): + if isinstance(other, VFraction): + r = VBool() + if builder is not None: + ee = [] + for i in range(2): + es = builder.extract_element( + self.get_ssa_value(builder), + lc.Constant.int(lc.Type.int(), i)) + eo = builder.extract_element( + other.get_ssa_value(builder), + lc.Constant.int(lc.Type.int(), i)) + ee.append(builder.icmp(lc.ICMP_EQ, es, eo)) + ssa_r = builder.and_(ee[0], ee[1]) + if ne: + ssa_r = builder.xor(ssa_r, + lc.Constant.int(lc.Type.int(1), 1)) + r.set_ssa_value(builder, ssa_r) + return r + else: + return NotImplemented - def o_eq(self, other, builder): - return self._o_eq_inv(other, builder, False) + def o_eq(self, other, builder): + return self._o_eq_inv(other, builder, False) - def o_ne(self, other, builder): - return self._o_eq_inv(other, builder, True) + def o_ne(self, other, builder): + return self._o_eq_inv(other, builder, True) - def _o_muldiv(self, other, builder, div, invert=False): - r = VFraction() - if isinstance(other, VInt): - if builder is None: - return r - else: - numerator, denominator = self._nd(builder, invert) - i = other.get_ssa_value(builder) - if div: - gcd = _gcd64(i, numerator) - i = builder.sdiv(i, gcd) - numerator = builder.sdiv(numerator, gcd) - denominator = builder.mul(denominator, i) - else: - gcd = _gcd64(i, denominator) - i = builder.sdiv(i, gcd) - denominator = builder.sdiv(denominator, gcd) - numerator = builder.mul(numerator, i) - self.set_ssa_value(builder, _frac_make_ssa(builder, numerator, denominator)) - elif isinstance(other, VFraction): - if builder is None: - return r - else: - numerator, denominator = self._nd(builder, invert) - onumerator, odenominator = other._nd(builder) - if div: - numerator = builder.mul(numerator, odenominator) - denominator = builder.mul(denominator, onumerator) - else: - numerator = builder.mul(numerator, onumerator) - denominator = builder.mul(denominator, odenominator) - numerator, denominator = _frac_normalize(builder, numerator, denominator) - self.set_ssa_value(builder, _frac_make_ssa(builder, numerator, denominator)) - else: - return NotImplemented + def _o_muldiv(self, other, builder, div, invert=False): + r = VFraction() + if isinstance(other, VInt): + if builder is None: + return r + else: + numerator, denominator = self._nd(builder, invert) + i = other.get_ssa_value(builder) + if div: + gcd = _gcd64(i, numerator) + i = builder.sdiv(i, gcd) + numerator = builder.sdiv(numerator, gcd) + denominator = builder.mul(denominator, i) + else: + gcd = _gcd64(i, denominator) + i = builder.sdiv(i, gcd) + denominator = builder.sdiv(denominator, gcd) + numerator = builder.mul(numerator, i) + self.set_ssa_value(builder, _frac_make_ssa(builder, numerator, + denominator)) + elif isinstance(other, VFraction): + if builder is None: + return r + else: + numerator, denominator = self._nd(builder, invert) + onumerator, odenominator = other._nd(builder) + if div: + numerator = builder.mul(numerator, odenominator) + denominator = builder.mul(denominator, onumerator) + else: + numerator = builder.mul(numerator, onumerator) + denominator = builder.mul(denominator, odenominator) + numerator, denominator = _frac_normalize(builder, numerator, + denominator) + self.set_ssa_value( + builder, _frac_make_ssa(builder, numerator, denominator)) + else: + return NotImplemented - def o_mul(self, other, builder): - return self._o_muldiv(other, builder, False) + def o_mul(self, other, builder): + return self._o_muldiv(other, builder, False) - def o_truediv(self, other, builder): - return self._o_muldiv(other, builder, True) + def o_truediv(self, other, builder): + return self._o_muldiv(other, builder, True) - def or_mul(self, other, builder): - return self._o_muldiv(other, builder, False) + def or_mul(self, other, builder): + return self._o_muldiv(other, builder, False) - def or_truediv(self, other, builder): - return self._o_muldiv(other, builder, False, True) + def or_truediv(self, other, builder): + return self._o_muldiv(other, builder, False, True) - def o_floordiv(self, other, builder): - r = self.o_truediv(other, builder) - if r is NotImplemented: - return r - else: - return r.o_int(builder) + def o_floordiv(self, other, builder): + r = self.o_truediv(other, builder) + if r is NotImplemented: + return r + else: + return r.o_int(builder) + + def or_floordiv(self, other, builder): + r = self.or_truediv(other, builder) + if r is NotImplemented: + return r + else: + return r.o_int(builder) - def or_floordiv(self, other, builder): - r = self.or_truediv(other, builder) - if r is NotImplemented: - return r - else: - return r.o_int(builder) # Operators def _make_unary_operator(op_name): - def op(x, builder): - try: - opf = getattr(x, "o_"+op_name) - except AttributeError: - raise TypeError("Unsupported operand type for {}: {}".format(op_name, type(x).__name__)) - return opf(builder) - return op + def op(x, builder): + try: + opf = getattr(x, "o_"+op_name) + except AttributeError: + raise TypeError( + "Unsupported operand type for {}: {}" + .format(op_name, type(x).__name__)) + return opf(builder) + return op + def _make_binary_operator(op_name): - def op(l, r, builder): - try: - opf = getattr(l, "o_"+op_name) - except AttributeError: - result = NotImplemented - else: - result = opf(r, builder) - if result is NotImplemented: - try: - ropf = getattr(r, "or_"+op_name) - except AttributeError: - result = NotImplemented - else: - result = ropf(l, builder) - if result is NotImplemented: - raise TypeError("Unsupported operand types for {}: {} and {}".format( - op_name, type(l).__name__, type(r).__name__)) - return result - return op + def op(l, r, builder): + try: + opf = getattr(l, "o_"+op_name) + except AttributeError: + result = NotImplemented + else: + result = opf(r, builder) + if result is NotImplemented: + try: + ropf = getattr(r, "or_"+op_name) + except AttributeError: + result = NotImplemented + else: + result = ropf(l, builder) + if result is NotImplemented: + raise TypeError( + "Unsupported operand types for {}: {} and {}" + .format(op_name, type(l).__name__, type(r).__name__)) + return result + return op + def _make_operators(): - d = dict() - for op_name in ("bool", "int", "int64", "round", "round64", "inv", "pos", "neg"): - d[op_name] = _make_unary_operator(op_name) - d["not_"] = _make_binary_operator("not") - for op_name in ("add", "sub", "mul", - "truediv", "floordiv", "mod", - "pow", "lshift", "rshift", "xor", - "eq", "ne", "lt", "le", "gt", "ge"): - d[op_name] = _make_binary_operator(op_name) - d["and_"] = _make_binary_operator("and") - d["or_"] = _make_binary_operator("or") - return SimpleNamespace(**d) + d = dict() + for op_name in ("bool", "int", "int64", "round", "round64", + "inv", "pos", "neg"): + d[op_name] = _make_unary_operator(op_name) + d["not_"] = _make_binary_operator("not") + for op_name in ("add", "sub", "mul", + "truediv", "floordiv", "mod", + "pow", "lshift", "rshift", "xor", + "eq", "ne", "lt", "le", "gt", "ge"): + d[op_name] = _make_binary_operator(op_name) + d["and_"] = _make_binary_operator("and") + d["or_"] = _make_binary_operator("or") + return SimpleNamespace(**d) operators = _make_operators() + def init_module(module): - func_type = lc.Type.function(lc.Type.int(64), - [lc.Type.int(64), lc.Type.int(64)]) - module.add_function(func_type, "__gcd64") + func_type = lc.Type.function( + lc.Type.int(64), [lc.Type.int(64), lc.Type.int(64)]) + module.add_function(func_type, "__gcd64") diff --git a/artiq/compiler/lower_time.py b/artiq/compiler/lower_time.py index 78104677e..f4e47d985 100644 --- a/artiq/compiler/lower_time.py +++ b/artiq/compiler/lower_time.py @@ -3,41 +3,48 @@ import ast from artiq.compiler.tools import value_to_ast from artiq.language.core import int64 + def _insert_int64(node): - return ast.copy_location( - ast.Call(func=ast.Name("int64", ast.Load()), - args=[node], - keywords=[], starargs=[], kwargs=[]), node) + return ast.copy_location( + ast.Call(func=ast.Name("int64", ast.Load()), + args=[node], + keywords=[], starargs=[], kwargs=[]), + node) + class _TimeLowerer(ast.NodeTransformer): - def visit_Call(self, node): - if isinstance(node.func, ast.Name) and node.func.id == "now": - return ast.copy_location(ast.Name("now", ast.Load()), node) - else: - self.generic_visit(node) - return node + def visit_Call(self, node): + if isinstance(node.func, ast.Name) and node.func.id == "now": + return ast.copy_location(ast.Name("now", ast.Load()), node) + else: + self.generic_visit(node) + return node + + def visit_Expr(self, node): + self.generic_visit(node) + if (isinstance(node.value, ast.Call) + and isinstance(node.value.func, ast.Name)): + funcname = node.value.func.id + if funcname == "delay": + return ast.copy_location( + ast.AugAssign(target=ast.Name("now", ast.Store()), + op=ast.Add(), + value=_insert_int64(node.value.args[0])), + node) + elif funcname == "at": + return ast.copy_location( + ast.Assign(targets=[ast.Name("now", ast.Store())], + value=_insert_int64(node.value.args[0])), + node) + else: + return node + else: + return node - def visit_Expr(self, node): - self.generic_visit(node) - if isinstance(node.value, ast.Call) and isinstance(node.value.func, ast.Name): - funcname = node.value.func.id - if funcname == "delay": - return ast.copy_location( - ast.AugAssign(target=ast.Name("now", ast.Store()), op=ast.Add(), - value=_insert_int64(node.value.args[0])), - node) - elif funcname == "at": - return ast.copy_location( - ast.Assign(targets=[ast.Name("now", ast.Store())], - value=_insert_int64(node.value.args[0])), - node) - else: - return node - else: - return node def lower_time(funcdef, initial_time): - _TimeLowerer().visit(funcdef) - funcdef.body.insert(0, ast.copy_location( - ast.Assign(targets=[ast.Name("now", ast.Store())], value=value_to_ast(int64(initial_time))), - funcdef)) + _TimeLowerer().visit(funcdef) + funcdef.body.insert(0, ast.copy_location( + ast.Assign(targets=[ast.Name("now", ast.Store())], + value=value_to_ast(int64(initial_time))), + funcdef)) diff --git a/artiq/compiler/lower_units.py b/artiq/compiler/lower_units.py index f151f5534..2b3001568 100644 --- a/artiq/compiler/lower_units.py +++ b/artiq/compiler/lower_units.py @@ -3,6 +3,7 @@ import ast from artiq.compiler.tools import value_to_ast from artiq.language import units + # TODO: # * track variable and expression dimensions # * raise exception on dimension errors in expressions @@ -11,32 +12,36 @@ from artiq.language import units # e.g. foo = now() + 1*us [...] at(foo) class _UnitsLowerer(ast.NodeTransformer): - def __init__(self, ref_period): - self.ref_period = ref_period - self.in_core_time = False + def __init__(self, ref_period): + self.ref_period = ref_period + self.in_core_time = False + + def visit_Call(self, node): + fn = node.func.id + if fn in ("delay", "at"): + old_in_core_time = self.in_core_time + self.in_core_time = True + self.generic_visit(node) + self.in_core_time = old_in_core_time + elif fn == "Quantity": + if self.in_core_time: + if node.args[1].id == "microcycle_units": + node = node.args[0] + else: + node = ast.copy_location( + ast.BinOp(left=node.args[0], + op=ast.Div(), + right=value_to_ast(self.ref_period)), + node) + else: + node = node.args[0] + else: + self.generic_visit(node) + return node - def visit_Call(self, node): - fn = node.func.id - if fn in ("delay", "at"): - old_in_core_time = self.in_core_time - self.in_core_time = True - self.generic_visit(node) - self.in_core_time = old_in_core_time - elif fn == "Quantity": - if self.in_core_time: - if node.args[1].id == "microcycle_units": - node = node.args[0] - else: - node = ast.copy_location( - ast.BinOp(left=node.args[0], op=ast.Div(), right=value_to_ast(self.ref_period)), - node) - else: - node = node.args[0] - else: - self.generic_visit(node) - return node def lower_units(funcdef, ref_period): - if not isinstance(ref_period, units.Quantity) or ref_period.unit is not units.s_unit: - raise units.DimensionError("Reference period not expressed in seconds") - _UnitsLowerer(ref_period.amount).visit(funcdef) + if (not isinstance(ref_period, units.Quantity) + or ref_period.unit is not units.s_unit): + raise units.DimensionError("Reference period not expressed in seconds") + _UnitsLowerer(ref_period.amount).visit(funcdef) diff --git a/artiq/compiler/tools.py b/artiq/compiler/tools.py index 49cd16743..ae3942114 100644 --- a/artiq/compiler/tools.py +++ b/artiq/compiler/tools.py @@ -4,60 +4,67 @@ from fractions import Fraction from artiq.language import core as core_language from artiq.language import units + def eval_ast(expr, symdict=dict()): - if not isinstance(expr, ast.Expression): - expr = ast.copy_location(ast.Expression(expr), expr) - ast.fix_missing_locations(expr) - code = compile(expr, "", "eval") - return eval(code, symdict) + if not isinstance(expr, ast.Expression): + expr = ast.copy_location(ast.Expression(expr), expr) + ast.fix_missing_locations(expr) + code = compile(expr, "", "eval") + return eval(code, symdict) + def value_to_ast(value): - if isinstance(value, core_language.int64): # must be before int - return ast.Call( - func=ast.Name("int64", ast.Load()), - args=[ast.Num(int(value))], - keywords=[], starargs=None, kwargs=None) - elif isinstance(value, int): - return ast.Num(value) - elif isinstance(value, Fraction): - return ast.Call(func=ast.Name("Fraction", ast.Load()), - args=[ast.Num(value.numerator), ast.Num(value.denominator)], - keywords=[], starargs=None, kwargs=None) - elif isinstance(value, str): - return ast.Str(value) - else: - for kg in core_language.kernel_globals: - if value is getattr(core_language, kg): - return ast.Name(kg, ast.Load()) - if isinstance(value, units.Quantity): - return ast.Call( - func=ast.Name("Quantity", ast.Load()), - args=[value_to_ast(value.amount), ast.Name(value.unit.name+"_unit", ast.Load())], - keywords=[], starargs=None, kwargs=None) - return None + if isinstance(value, core_language.int64): # must be before int + return ast.Call( + func=ast.Name("int64", ast.Load()), + args=[ast.Num(int(value))], + keywords=[], starargs=None, kwargs=None) + elif isinstance(value, int): + return ast.Num(value) + elif isinstance(value, Fraction): + return ast.Call( + func=ast.Name("Fraction", ast.Load()), + args=[ast.Num(value.numerator), ast.Num(value.denominator)], + keywords=[], starargs=None, kwargs=None) + elif isinstance(value, str): + return ast.Str(value) + else: + for kg in core_language.kernel_globals: + if value is getattr(core_language, kg): + return ast.Name(kg, ast.Load()) + if isinstance(value, units.Quantity): + return ast.Call( + func=ast.Name("Quantity", ast.Load()), + args=[value_to_ast(value.amount), + ast.Name(value.unit.name+"_unit", ast.Load())], + keywords=[], starargs=None, kwargs=None) + return None + class NotConstant(Exception): - pass + pass + def eval_constant(node): - if isinstance(node, ast.Num): - return node.n - elif isinstance(node, ast.Str): - return node.s - elif isinstance(node, ast.Call): - funcname = node.func.id - if funcname == "Fraction": - numerator, denominator = eval_constant(node.args[0]), eval_constant(node.args[1]) - return Fraction(numerator, denominator) - elif funcname == "Quantity": - amount, unit = node.args - amount = eval_constant(amount) - try: - unit = getattr(units, unit.id) - except: - raise NotConstant - return units.Quantity(amount, unit) - else: - raise NotConstant - else: - raise NotConstant + if isinstance(node, ast.Num): + return node.n + elif isinstance(node, ast.Str): + return node.s + elif isinstance(node, ast.Call): + funcname = node.func.id + if funcname == "Fraction": + numerator = eval_constant(node.args[0]) + denominator = eval_constant(node.args[1]) + return Fraction(numerator, denominator) + elif funcname == "Quantity": + amount, unit = node.args + amount = eval_constant(amount) + try: + unit = getattr(units, unit.id) + except: + raise NotConstant + return units.Quantity(amount, unit) + else: + raise NotConstant + else: + raise NotConstant diff --git a/artiq/compiler/unparse.py b/artiq/compiler/unparse.py index be280132f..5d570dbf5 100644 --- a/artiq/compiler/unparse.py +++ b/artiq/compiler/unparse.py @@ -1,564 +1,596 @@ import sys import ast -import os + # Large float and imaginary literals get turned into infinities in the AST. # We unparse those infinities to INFSTR. INFSTR = "1e" + repr(sys.float_info.max_10_exp + 1) + def interleave(inter, f, seq): - """Call f on each item in seq, calling inter() in between. - """ - seq = iter(seq) - try: - f(next(seq)) - except StopIteration: - pass - else: - for x in seq: - inter() - f(x) + """Call f on each item in seq, calling inter() in between. + """ + seq = iter(seq) + try: + f(next(seq)) + except StopIteration: + pass + else: + for x in seq: + inter() + f(x) + class Unparser: - """Methods in this class recursively traverse an AST and - output source code for the abstract syntax; original formatting - is disregarded. """ - - def __init__(self, tree, file = sys.stdout): - """Unparser(tree, file=sys.stdout) -> None. - Print the source for tree to file.""" - self.f = file - self._indent = 0 - self.dispatch(tree) - print("", file=self.f) - self.f.flush() - - def fill(self, text = ""): - "Indent a piece of text, according to the current indentation level" - self.f.write("\n"+" "*self._indent + text) - - def write(self, text): - "Append a piece of text to the current line." - self.f.write(text) - - def enter(self): - "Print ':', and increase the indentation." - self.write(":") - self._indent += 1 - - def leave(self): - "Decrease the indentation level." - self._indent -= 1 - - def dispatch(self, tree): - "Dispatcher function, dispatching tree type T to method _T." - if isinstance(tree, list): - for t in tree: - self.dispatch(t) - return - meth = getattr(self, "_"+tree.__class__.__name__) - meth(tree) - - - ############### Unparsing methods ###################### - # There should be one method per concrete grammar type # - # Constructors should be grouped by sum type. Ideally, # - # this would follow the order in the grammar, but # - # currently doesn't. # - ######################################################## - - def _Module(self, tree): - for stmt in tree.body: - self.dispatch(stmt) - - # stmt - def _Expr(self, tree): - self.fill() - self.dispatch(tree.value) - - def _Import(self, t): - self.fill("import ") - interleave(lambda: self.write(", "), self.dispatch, t.names) - - def _ImportFrom(self, t): - self.fill("from ") - self.write("." * t.level) - if t.module: - self.write(t.module) - self.write(" import ") - interleave(lambda: self.write(", "), self.dispatch, t.names) - - def _Assign(self, t): - self.fill() - for target in t.targets: - self.dispatch(target) - self.write(" = ") - self.dispatch(t.value) - - def _AugAssign(self, t): - self.fill() - self.dispatch(t.target) - self.write(" "+self.binop[t.op.__class__.__name__]+"= ") - self.dispatch(t.value) - - def _Return(self, t): - self.fill("return") - if t.value: - self.write(" ") - self.dispatch(t.value) - - def _Pass(self, t): - self.fill("pass") - - def _Break(self, t): - self.fill("break") - - def _Continue(self, t): - self.fill("continue") - - def _Delete(self, t): - self.fill("del ") - interleave(lambda: self.write(", "), self.dispatch, t.targets) - - def _Assert(self, t): - self.fill("assert ") - self.dispatch(t.test) - if t.msg: - self.write(", ") - self.dispatch(t.msg) - - def _Global(self, t): - self.fill("global ") - interleave(lambda: self.write(", "), self.write, t.names) - - def _Nonlocal(self, t): - self.fill("nonlocal ") - interleave(lambda: self.write(", "), self.write, t.names) - - def _Yield(self, t): - self.write("(") - self.write("yield") - if t.value: - self.write(" ") - self.dispatch(t.value) - self.write(")") - - def _YieldFrom(self, t): - self.write("(") - self.write("yield from") - if t.value: - self.write(" ") - self.dispatch(t.value) - self.write(")") - - def _Raise(self, t): - self.fill("raise") - if not t.exc: - assert not t.cause - return - self.write(" ") - self.dispatch(t.exc) - if t.cause: - self.write(" from ") - self.dispatch(t.cause) - - def _Try(self, t): - self.fill("try") - self.enter() - self.dispatch(t.body) - self.leave() - for ex in t.handlers: - self.dispatch(ex) - if t.orelse: - self.fill("else") - self.enter() - self.dispatch(t.orelse) - self.leave() - if t.finalbody: - self.fill("finally") - self.enter() - self.dispatch(t.finalbody) - self.leave() - - def _ExceptHandler(self, t): - self.fill("except") - if t.type: - self.write(" ") - self.dispatch(t.type) - if t.name: - self.write(" as ") - self.write(t.name) - self.enter() - self.dispatch(t.body) - self.leave() - - def _ClassDef(self, t): - self.write("\n") - for deco in t.decorator_list: - self.fill("@") - self.dispatch(deco) - self.fill("class "+t.name) - self.write("(") - comma = False - for e in t.bases: - if comma: self.write(", ") - else: comma = True - self.dispatch(e) - for e in t.keywords: - if comma: self.write(", ") - else: comma = True - self.dispatch(e) - if t.starargs: - if comma: self.write(", ") - else: comma = True - self.write("*") - self.dispatch(t.starargs) - if t.kwargs: - if comma: self.write(", ") - else: comma = True - self.write("**") - self.dispatch(t.kwargs) - self.write(")") - - self.enter() - self.dispatch(t.body) - self.leave() - - def _FunctionDef(self, t): - self.write("\n") - for deco in t.decorator_list: - self.fill("@") - self.dispatch(deco) - self.fill("def "+t.name + "(") - self.dispatch(t.args) - self.write(")") - if t.returns: - self.write(" -> ") - self.dispatch(t.returns) - self.enter() - self.dispatch(t.body) - self.leave() - - def _For(self, t): - self.fill("for ") - self.dispatch(t.target) - self.write(" in ") - self.dispatch(t.iter) - self.enter() - self.dispatch(t.body) - self.leave() - if t.orelse: - self.fill("else") - self.enter() - self.dispatch(t.orelse) - self.leave() - - def _If(self, t): - self.fill("if ") - self.dispatch(t.test) - self.enter() - self.dispatch(t.body) - self.leave() - # collapse nested ifs into equivalent elifs. - while (t.orelse and len(t.orelse) == 1 and - isinstance(t.orelse[0], ast.If)): - t = t.orelse[0] - self.fill("elif ") - self.dispatch(t.test) - self.enter() - self.dispatch(t.body) - self.leave() - # final else - if t.orelse: - self.fill("else") - self.enter() - self.dispatch(t.orelse) - self.leave() - - def _While(self, t): - self.fill("while ") - self.dispatch(t.test) - self.enter() - self.dispatch(t.body) - self.leave() - if t.orelse: - self.fill("else") - self.enter() - self.dispatch(t.orelse) - self.leave() - - def _With(self, t): - self.fill("with ") - interleave(lambda: self.write(", "), self.dispatch, t.items) - self.enter() - self.dispatch(t.body) - self.leave() - - # expr - def _Bytes(self, t): - self.write(repr(t.s)) - - def _Str(self, tree): - self.write(repr(tree.s)) - - def _Name(self, t): - self.write(t.id) - - def _NameConstant(self, t): - self.write(repr(t.value)) - - def _Num(self, t): - # Substitute overflowing decimal literal for AST infinities. - self.write(repr(t.n).replace("inf", INFSTR)) - - def _List(self, t): - self.write("[") - interleave(lambda: self.write(", "), self.dispatch, t.elts) - self.write("]") - - def _ListComp(self, t): - self.write("[") - self.dispatch(t.elt) - for gen in t.generators: - self.dispatch(gen) - self.write("]") - - def _GeneratorExp(self, t): - self.write("(") - self.dispatch(t.elt) - for gen in t.generators: - self.dispatch(gen) - self.write(")") - - def _SetComp(self, t): - self.write("{") - self.dispatch(t.elt) - for gen in t.generators: - self.dispatch(gen) - self.write("}") - - def _DictComp(self, t): - self.write("{") - self.dispatch(t.key) - self.write(": ") - self.dispatch(t.value) - for gen in t.generators: - self.dispatch(gen) - self.write("}") - - def _comprehension(self, t): - self.write(" for ") - self.dispatch(t.target) - self.write(" in ") - self.dispatch(t.iter) - for if_clause in t.ifs: - self.write(" if ") - self.dispatch(if_clause) - - def _IfExp(self, t): - self.write("(") - self.dispatch(t.body) - self.write(" if ") - self.dispatch(t.test) - self.write(" else ") - self.dispatch(t.orelse) - self.write(")") - - def _Set(self, t): - assert(t.elts) # should be at least one element - self.write("{") - interleave(lambda: self.write(", "), self.dispatch, t.elts) - self.write("}") - - def _Dict(self, t): - self.write("{") - def write_pair(pair): - (k, v) = pair - self.dispatch(k) - self.write(": ") - self.dispatch(v) - interleave(lambda: self.write(", "), write_pair, zip(t.keys, t.values)) - self.write("}") - - def _Tuple(self, t): - self.write("(") - if len(t.elts) == 1: - (elt,) = t.elts - self.dispatch(elt) - self.write(",") - else: - interleave(lambda: self.write(", "), self.dispatch, t.elts) - self.write(")") - - unop = {"Invert":"~", "Not": "not", "UAdd":"+", "USub":"-"} - def _UnaryOp(self, t): - self.write("(") - self.write(self.unop[t.op.__class__.__name__]) - self.write(" ") - self.dispatch(t.operand) - self.write(")") - - binop = { "Add":"+", "Sub":"-", "Mult":"*", "Div":"/", "Mod":"%", - "LShift":"<<", "RShift":">>", "BitOr":"|", "BitXor":"^", "BitAnd":"&", - "FloorDiv":"//", "Pow": "**"} - def _BinOp(self, t): - self.write("(") - self.dispatch(t.left) - self.write(" " + self.binop[t.op.__class__.__name__] + " ") - self.dispatch(t.right) - self.write(")") - - cmpops = {"Eq":"==", "NotEq":"!=", "Lt":"<", "LtE":"<=", "Gt":">", "GtE":">=", - "Is":"is", "IsNot":"is not", "In":"in", "NotIn":"not in"} - def _Compare(self, t): - self.write("(") - self.dispatch(t.left) - for o, e in zip(t.ops, t.comparators): - self.write(" " + self.cmpops[o.__class__.__name__] + " ") - self.dispatch(e) - self.write(")") - - boolops = {ast.And: 'and', ast.Or: 'or'} - def _BoolOp(self, t): - self.write("(") - s = " %s " % self.boolops[t.op.__class__] - interleave(lambda: self.write(s), self.dispatch, t.values) - self.write(")") - - def _Attribute(self,t): - self.dispatch(t.value) - # Special case: 3.__abs__() is a syntax error, so if t.value - # is an integer literal then we need to either parenthesize - # it or add an extra space to get 3 .__abs__(). - if isinstance(t.value, ast.Num) and isinstance(t.value.n, int): - self.write(" ") - self.write(".") - self.write(t.attr) - - def _Call(self, t): - self.dispatch(t.func) - self.write("(") - comma = False - for e in t.args: - if comma: self.write(", ") - else: comma = True - self.dispatch(e) - for e in t.keywords: - if comma: self.write(", ") - else: comma = True - self.dispatch(e) - if t.starargs: - if comma: self.write(", ") - else: comma = True - self.write("*") - self.dispatch(t.starargs) - if t.kwargs: - if comma: self.write(", ") - else: comma = True - self.write("**") - self.dispatch(t.kwargs) - self.write(")") - - def _Subscript(self, t): - self.dispatch(t.value) - self.write("[") - self.dispatch(t.slice) - self.write("]") - - def _Starred(self, t): - self.write("*") - self.dispatch(t.value) - - # slice - def _Ellipsis(self, t): - self.write("...") - - def _Index(self, t): - self.dispatch(t.value) - - def _Slice(self, t): - if t.lower: - self.dispatch(t.lower) - self.write(":") - if t.upper: - self.dispatch(t.upper) - if t.step: - self.write(":") - self.dispatch(t.step) - - def _ExtSlice(self, t): - interleave(lambda: self.write(', '), self.dispatch, t.dims) - - # argument - def _arg(self, t): - self.write(t.arg) - if t.annotation: - self.write(": ") - self.dispatch(t.annotation) - - # others - def _arguments(self, t): - first = True - # normal arguments - defaults = [None] * (len(t.args) - len(t.defaults)) + t.defaults - for a, d in zip(t.args, defaults): - if first:first = False - else: self.write(", ") - self.dispatch(a) - if d: - self.write("=") - self.dispatch(d) - - # varargs, or bare '*' if no varargs but keyword-only arguments present - if t.vararg or t.kwonlyargs: - if first:first = False - else: self.write(", ") - self.write("*") - if t.vararg: - self.write(t.vararg.arg) - if t.vararg.annotation: - self.write(": ") - self.dispatch(t.vararg.annotation) - - # keyword-only arguments - if t.kwonlyargs: - for a, d in zip(t.kwonlyargs, t.kw_defaults): - if first:first = False - else: self.write(", ") - self.dispatch(a), - if d: - self.write("=") - self.dispatch(d) - - # kwargs - if t.kwarg: - if first:first = False - else: self.write(", ") - self.write("**"+t.kwarg.arg) - if t.kwarg.annotation: - self.write(": ") - self.dispatch(t.kwarg.annotation) - - def _keyword(self, t): - self.write(t.arg) - self.write("=") - self.dispatch(t.value) - - def _Lambda(self, t): - self.write("(") - self.write("lambda ") - self.dispatch(t.args) - self.write(": ") - self.dispatch(t.body) - self.write(")") - - def _alias(self, t): - self.write(t.name) - if t.asname: - self.write(" as "+t.asname) - - def _withitem(self, t): - self.dispatch(t.context_expr) - if t.optional_vars: - self.write(" as ") - self.dispatch(t.optional_vars) + """Methods in this class recursively traverse an AST and + output source code for the abstract syntax; original formatting + is disregarded. """ + + def __init__(self, tree, file=sys.stdout): + """Unparser(tree, file=sys.stdout) -> None. + Print the source for tree to file.""" + self.f = file + self._indent = 0 + self.dispatch(tree) + print("", file=self.f) + self.f.flush() + + def fill(self, text=""): + "Indent a piece of text, according to the current indentation level" + self.f.write("\n"+" "*self._indent + text) + + def write(self, text): + "Append a piece of text to the current line." + self.f.write(text) + + def enter(self): + "Print ':', and increase the indentation." + self.write(":") + self._indent += 1 + + def leave(self): + "Decrease the indentation level." + self._indent -= 1 + + def dispatch(self, tree): + "Dispatcher function, dispatching tree type T to method _T." + if isinstance(tree, list): + for t in tree: + self.dispatch(t) + return + meth = getattr(self, "_"+tree.__class__.__name__) + meth(tree) + + # Unparsing methods + # + # There should be one method per concrete grammar type + # Constructors should be grouped by sum type. Ideally, + # this would follow the order in the grammar, but + # currently doesn't. + + def _Module(self, tree): + for stmt in tree.body: + self.dispatch(stmt) + + # stmt + def _Expr(self, tree): + self.fill() + self.dispatch(tree.value) + + def _Import(self, t): + self.fill("import ") + interleave(lambda: self.write(", "), self.dispatch, t.names) + + def _ImportFrom(self, t): + self.fill("from ") + self.write("." * t.level) + if t.module: + self.write(t.module) + self.write(" import ") + interleave(lambda: self.write(", "), self.dispatch, t.names) + + def _Assign(self, t): + self.fill() + for target in t.targets: + self.dispatch(target) + self.write(" = ") + self.dispatch(t.value) + + def _AugAssign(self, t): + self.fill() + self.dispatch(t.target) + self.write(" "+self.binop[t.op.__class__.__name__]+"= ") + self.dispatch(t.value) + + def _Return(self, t): + self.fill("return") + if t.value: + self.write(" ") + self.dispatch(t.value) + + def _Pass(self, t): + self.fill("pass") + + def _Break(self, t): + self.fill("break") + + def _Continue(self, t): + self.fill("continue") + + def _Delete(self, t): + self.fill("del ") + interleave(lambda: self.write(", "), self.dispatch, t.targets) + + def _Assert(self, t): + self.fill("assert ") + self.dispatch(t.test) + if t.msg: + self.write(", ") + self.dispatch(t.msg) + + def _Global(self, t): + self.fill("global ") + interleave(lambda: self.write(", "), self.write, t.names) + + def _Nonlocal(self, t): + self.fill("nonlocal ") + interleave(lambda: self.write(", "), self.write, t.names) + + def _Yield(self, t): + self.write("(") + self.write("yield") + if t.value: + self.write(" ") + self.dispatch(t.value) + self.write(")") + + def _YieldFrom(self, t): + self.write("(") + self.write("yield from") + if t.value: + self.write(" ") + self.dispatch(t.value) + self.write(")") + + def _Raise(self, t): + self.fill("raise") + if not t.exc: + assert not t.cause + return + self.write(" ") + self.dispatch(t.exc) + if t.cause: + self.write(" from ") + self.dispatch(t.cause) + + def _Try(self, t): + self.fill("try") + self.enter() + self.dispatch(t.body) + self.leave() + for ex in t.handlers: + self.dispatch(ex) + if t.orelse: + self.fill("else") + self.enter() + self.dispatch(t.orelse) + self.leave() + if t.finalbody: + self.fill("finally") + self.enter() + self.dispatch(t.finalbody) + self.leave() + + def _ExceptHandler(self, t): + self.fill("except") + if t.type: + self.write(" ") + self.dispatch(t.type) + if t.name: + self.write(" as ") + self.write(t.name) + self.enter() + self.dispatch(t.body) + self.leave() + + def _ClassDef(self, t): + self.write("\n") + for deco in t.decorator_list: + self.fill("@") + self.dispatch(deco) + self.fill("class "+t.name) + self.write("(") + comma = False + for e in t.bases: + if comma: + self.write(", ") + else: + comma = True + self.dispatch(e) + for e in t.keywords: + if comma: + self.write(", ") + else: + comma = True + self.dispatch(e) + if t.starargs: + if comma: + self.write(", ") + else: + comma = True + self.write("*") + self.dispatch(t.starargs) + if t.kwargs: + if comma: + self.write(", ") + else: + comma = True + self.write("**") + self.dispatch(t.kwargs) + self.write(")") + + self.enter() + self.dispatch(t.body) + self.leave() + + def _FunctionDef(self, t): + self.write("\n") + for deco in t.decorator_list: + self.fill("@") + self.dispatch(deco) + self.fill("def "+t.name + "(") + self.dispatch(t.args) + self.write(")") + if t.returns: + self.write(" -> ") + self.dispatch(t.returns) + self.enter() + self.dispatch(t.body) + self.leave() + + def _For(self, t): + self.fill("for ") + self.dispatch(t.target) + self.write(" in ") + self.dispatch(t.iter) + self.enter() + self.dispatch(t.body) + self.leave() + if t.orelse: + self.fill("else") + self.enter() + self.dispatch(t.orelse) + self.leave() + + def _If(self, t): + self.fill("if ") + self.dispatch(t.test) + self.enter() + self.dispatch(t.body) + self.leave() + # collapse nested ifs into equivalent elifs. + while (t.orelse and len(t.orelse) == 1 and + isinstance(t.orelse[0], ast.If)): + t = t.orelse[0] + self.fill("elif ") + self.dispatch(t.test) + self.enter() + self.dispatch(t.body) + self.leave() + # final else + if t.orelse: + self.fill("else") + self.enter() + self.dispatch(t.orelse) + self.leave() + + def _While(self, t): + self.fill("while ") + self.dispatch(t.test) + self.enter() + self.dispatch(t.body) + self.leave() + if t.orelse: + self.fill("else") + self.enter() + self.dispatch(t.orelse) + self.leave() + + def _With(self, t): + self.fill("with ") + interleave(lambda: self.write(", "), self.dispatch, t.items) + self.enter() + self.dispatch(t.body) + self.leave() + + # expr + def _Bytes(self, t): + self.write(repr(t.s)) + + def _Str(self, tree): + self.write(repr(tree.s)) + + def _Name(self, t): + self.write(t.id) + + def _NameConstant(self, t): + self.write(repr(t.value)) + + def _Num(self, t): + # Substitute overflowing decimal literal for AST infinities. + self.write(repr(t.n).replace("inf", INFSTR)) + + def _List(self, t): + self.write("[") + interleave(lambda: self.write(", "), self.dispatch, t.elts) + self.write("]") + + def _ListComp(self, t): + self.write("[") + self.dispatch(t.elt) + for gen in t.generators: + self.dispatch(gen) + self.write("]") + + def _GeneratorExp(self, t): + self.write("(") + self.dispatch(t.elt) + for gen in t.generators: + self.dispatch(gen) + self.write(")") + + def _SetComp(self, t): + self.write("{") + self.dispatch(t.elt) + for gen in t.generators: + self.dispatch(gen) + self.write("}") + + def _DictComp(self, t): + self.write("{") + self.dispatch(t.key) + self.write(": ") + self.dispatch(t.value) + for gen in t.generators: + self.dispatch(gen) + self.write("}") + + def _comprehension(self, t): + self.write(" for ") + self.dispatch(t.target) + self.write(" in ") + self.dispatch(t.iter) + for if_clause in t.ifs: + self.write(" if ") + self.dispatch(if_clause) + + def _IfExp(self, t): + self.write("(") + self.dispatch(t.body) + self.write(" if ") + self.dispatch(t.test) + self.write(" else ") + self.dispatch(t.orelse) + self.write(")") + + def _Set(self, t): + assert(t.elts) # should be at least one element + self.write("{") + interleave(lambda: self.write(", "), self.dispatch, t.elts) + self.write("}") + + def _Dict(self, t): + self.write("{") + + def write_pair(pair): + (k, v) = pair + self.dispatch(k) + self.write(": ") + self.dispatch(v) + interleave(lambda: self.write(", "), write_pair, zip(t.keys, t.values)) + self.write("}") + + def _Tuple(self, t): + self.write("(") + if len(t.elts) == 1: + (elt,) = t.elts + self.dispatch(elt) + self.write(",") + else: + interleave(lambda: self.write(", "), self.dispatch, t.elts) + self.write(")") + + unop = {"Invert": "~", "Not": "not", "UAdd": "+", "USub": "-"} + + def _UnaryOp(self, t): + self.write("(") + self.write(self.unop[t.op.__class__.__name__]) + self.write(" ") + self.dispatch(t.operand) + self.write(")") + + binop = {"Add": "+", "Sub": "-", "Mult": "*", "Div": "/", "Mod": "%", + "LShift": "<<", "RShift": ">>", + "BitOr": "|", "BitXor": "^", "BitAnd": "&", + "FloorDiv": "//", "Pow": "**"} + + def _BinOp(self, t): + self.write("(") + self.dispatch(t.left) + self.write(" " + self.binop[t.op.__class__.__name__] + " ") + self.dispatch(t.right) + self.write(")") + + cmpops = {"Eq": "==", "NotEq": "!=", + "Lt": "<", "LtE": "<=", "Gt": ">", "GtE": ">=", + "Is": "is", "IsNot": "is not", "In": "in", "NotIn": "not in"} + + def _Compare(self, t): + self.write("(") + self.dispatch(t.left) + for o, e in zip(t.ops, t.comparators): + self.write(" " + self.cmpops[o.__class__.__name__] + " ") + self.dispatch(e) + self.write(")") + + boolops = {ast.And: "and", ast.Or: "or"} + + def _BoolOp(self, t): + self.write("(") + s = " %s " % self.boolops[t.op.__class__] + interleave(lambda: self.write(s), self.dispatch, t.values) + self.write(")") + + def _Attribute(self, t): + self.dispatch(t.value) + # Special case: 3.__abs__() is a syntax error, so if t.value + # is an integer literal then we need to either parenthesize + # it or add an extra space to get 3 .__abs__(). + if isinstance(t.value, ast.Num) and isinstance(t.value.n, int): + self.write(" ") + self.write(".") + self.write(t.attr) + + def _Call(self, t): + self.dispatch(t.func) + self.write("(") + comma = False + for e in t.args: + if comma: + self.write(", ") + else: + comma = True + self.dispatch(e) + for e in t.keywords: + if comma: + self.write(", ") + else: + comma = True + self.dispatch(e) + if t.starargs: + if comma: + self.write(", ") + else: + comma = True + self.write("*") + self.dispatch(t.starargs) + if t.kwargs: + if comma: + self.write(", ") + else: + comma = True + self.write("**") + self.dispatch(t.kwargs) + self.write(")") + + def _Subscript(self, t): + self.dispatch(t.value) + self.write("[") + self.dispatch(t.slice) + self.write("]") + + def _Starred(self, t): + self.write("*") + self.dispatch(t.value) + + # slice + def _Ellipsis(self, t): + self.write("...") + + def _Index(self, t): + self.dispatch(t.value) + + def _Slice(self, t): + if t.lower: + self.dispatch(t.lower) + self.write(":") + if t.upper: + self.dispatch(t.upper) + if t.step: + self.write(":") + self.dispatch(t.step) + + def _ExtSlice(self, t): + interleave(lambda: self.write(', '), self.dispatch, t.dims) + + # argument + def _arg(self, t): + self.write(t.arg) + if t.annotation: + self.write(": ") + self.dispatch(t.annotation) + + # others + def _arguments(self, t): + first = True + # normal arguments + defaults = [None] * (len(t.args) - len(t.defaults)) + t.defaults + for a, d in zip(t.args, defaults): + if first: + first = False + else: + self.write(", ") + self.dispatch(a) + if d: + self.write("=") + self.dispatch(d) + + # varargs, or bare '*' if no varargs but keyword-only arguments present + if t.vararg or t.kwonlyargs: + if first: + first = False + else: + self.write(", ") + self.write("*") + if t.vararg: + self.write(t.vararg.arg) + if t.vararg.annotation: + self.write(": ") + self.dispatch(t.vararg.annotation) + + # keyword-only arguments + if t.kwonlyargs: + for a, d in zip(t.kwonlyargs, t.kw_defaults): + if first: + first = False + else: + self.write(", ") + self.dispatch(a), + if d: + self.write("=") + self.dispatch(d) + + # kwargs + if t.kwarg: + if first: + first = False + else: + self.write(", ") + self.write("**"+t.kwarg.arg) + if t.kwarg.annotation: + self.write(": ") + self.dispatch(t.kwarg.annotation) + + def _keyword(self, t): + self.write(t.arg) + self.write("=") + self.dispatch(t.value) + + def _Lambda(self, t): + self.write("(") + self.write("lambda ") + self.dispatch(t.args) + self.write(": ") + self.dispatch(t.body) + self.write(")") + + def _alias(self, t): + self.write(t.name) + if t.asname: + self.write(" as "+t.asname) + + def _withitem(self, t): + self.dispatch(t.context_expr) + if t.optional_vars: + self.write(" as ") + self.dispatch(t.optional_vars) diff --git a/artiq/compiler/unroll_loops.py b/artiq/compiler/unroll_loops.py index 6888ecf9b..632385f37 100644 --- a/artiq/compiler/unroll_loops.py +++ b/artiq/compiler/unroll_loops.py @@ -2,46 +2,51 @@ import ast from artiq.compiler.tools import eval_ast, value_to_ast + def _count_stmts(node): - if isinstance(node, (ast.For, ast.While, ast.If)): - return 1 + _count_stmts(node.body) + _count_stmts(node.orelse) - elif isinstance(node, ast.With): - return 1 + _count_stmts(node.body) - elif isinstance(node, list): - return sum(map(_count_stmts, node)) - else: - return 1 + if isinstance(node, (ast.For, ast.While, ast.If)): + return 1 + _count_stmts(node.body) + _count_stmts(node.orelse) + elif isinstance(node, ast.With): + return 1 + _count_stmts(node.body) + elif isinstance(node, list): + return sum(map(_count_stmts, node)) + else: + return 1 + class _LoopUnroller(ast.NodeTransformer): - def __init__(self, limit): - self.limit = limit + def __init__(self, limit): + self.limit = limit + + def visit_For(self, node): + self.generic_visit(node) + try: + it = eval_ast(node.iter) + except: + return node + l_it = len(it) + if l_it: + n = l_it*_count_stmts(node.body) + if n < self.limit: + replacement = [] + for i in it: + if not isinstance(i, int): + replacement = None + break + replacement.append(ast.copy_location( + ast.Assign(targets=[node.target], + value=value_to_ast(i)), + node)) + replacement += node.body + if replacement is not None: + return replacement + else: + return node + else: + return node + else: + return node.orelse - def visit_For(self, node): - self.generic_visit(node) - try: - it = eval_ast(node.iter) - except: - return node - l_it = len(it) - if l_it: - n = l_it*_count_stmts(node.body) - if n < self.limit: - replacement = [] - for i in it: - if not isinstance(i, int): - replacement = None - break - replacement.append(ast.copy_location( - ast.Assign(targets=[node.target], value=value_to_ast(i)), node)) - replacement += node.body - if replacement is not None: - return replacement - else: - return node - else: - return node - else: - return node.orelse def unroll_loops(node, limit): - _LoopUnroller(limit).visit(node) + _LoopUnroller(limit).visit(node) diff --git a/artiq/devices/core.py b/artiq/devices/core.py index e829160a4..6691a7e7f 100644 --- a/artiq/devices/core.py +++ b/artiq/devices/core.py @@ -6,22 +6,23 @@ from artiq.compiler.interleave import interleave from artiq.compiler.lower_time import lower_time from artiq.compiler.ir import get_runtime_binary + class Core: - def __init__(self, core_com, runtime_env=None): - if runtime_env is None: - runtime_env = core_com.get_runtime_env() - self.runtime_env = runtime_env - self.core_com = core_com + def __init__(self, core_com, runtime_env=None): + if runtime_env is None: + runtime_env = core_com.get_runtime_env() + self.runtime_env = runtime_env + self.core_com = core_com - def run(self, k_function, k_args, k_kwargs): - funcdef, rpc_map = inline(self, k_function, k_args, k_kwargs) - lower_units(funcdef, self.runtime_env.ref_period) - fold_constants(funcdef) - unroll_loops(funcdef, 50) - interleave(funcdef) - lower_time(funcdef, getattr(self.runtime_env, "initial_time", 0)) - fold_constants(funcdef) + def run(self, k_function, k_args, k_kwargs): + funcdef, rpc_map = inline(self, k_function, k_args, k_kwargs) + lower_units(funcdef, self.runtime_env.ref_period) + fold_constants(funcdef) + unroll_loops(funcdef, 50) + interleave(funcdef) + lower_time(funcdef, getattr(self.runtime_env, "initial_time", 0)) + fold_constants(funcdef) - binary = get_runtime_binary(self.runtime_env, funcdef) - self.core_com.run(binary) - self.core_com.serve(rpc_map) + binary = get_runtime_binary(self.runtime_env, funcdef) + self.core_com.run(binary) + self.core_com.serve(rpc_map) diff --git a/artiq/devices/corecom_dummy.py b/artiq/devices/corecom_dummy.py index c06d3fe6b..4d3cb5aef 100644 --- a/artiq/devices/corecom_dummy.py +++ b/artiq/devices/corecom_dummy.py @@ -3,26 +3,28 @@ from operator import itemgetter from artiq.devices.runtime import LinkInterface from artiq.language.units import ns -class _RuntimeEnvironment(LinkInterface): - def __init__(self, ref_period): - self.ref_period = ref_period - def emit_object(self): - return str(self.module) +class _RuntimeEnvironment(LinkInterface): + def __init__(self, ref_period): + self.ref_period = ref_period + + def emit_object(self): + return str(self.module) + class CoreCom: - def get_runtime_env(self): - return _RuntimeEnvironment(10*ns) + def get_runtime_env(self): + return _RuntimeEnvironment(10*ns) - def run(self, kcode): - print("================") - print(" LLVM IR") - print("================") - print(kcode) + def run(self, kcode): + print("================") + print(" LLVM IR") + print("================") + print(kcode) - def serve(self, rpc_map): - print("================") - print(" RPC map") - print("================") - for k, v in sorted(rpc_map.items(), key=itemgetter(0)): - print(str(k)+" -> "+str(v)) + def serve(self, rpc_map): + print("================") + print(" RPC map") + print("================") + for k, v in sorted(rpc_map.items(), key=itemgetter(0)): + print(str(k)+" -> "+str(v)) diff --git a/artiq/devices/corecom_serial.py b/artiq/devices/corecom_serial.py index 1b45b3fe6..3a62690ea 100644 --- a/artiq/devices/corecom_serial.py +++ b/artiq/devices/corecom_serial.py @@ -1,111 +1,126 @@ -import os, termios, struct, zlib +import os +import termios +import struct +import zlib from enum import Enum from artiq.language import units from artiq.devices.runtime import Environment + class UnsupportedDevice(Exception): - pass + pass + class _MsgType(Enum): - REQUEST_IDENT = 0x01 - LOAD_KERNEL = 0x02 - KERNEL_FINISHED = 0x03 - RPC_REQUEST = 0x04 + REQUEST_IDENT = 0x01 + LOAD_KERNEL = 0x02 + KERNEL_FINISHED = 0x03 + RPC_REQUEST = 0x04 + def _write_exactly(f, data): - remaining = len(data) - pos = 0 - while remaining: - written = f.write(data[pos:]) - remaining -= written - pos += written + remaining = len(data) + pos = 0 + while remaining: + written = f.write(data[pos:]) + remaining -= written + pos += written + def _read_exactly(f, n): - r = bytes() - while(len(r) < n): - r += f.read(n - len(r)) - return r + r = bytes() + while(len(r) < n): + r += f.read(n - len(r)) + return r + class CoreCom: - def __init__(self, dev="/dev/ttyUSB1", baud=115200): - self._fd = os.open(dev, os.O_RDWR | os.O_NOCTTY) - self.port = os.fdopen(self._fd, "r+b", buffering=0) - iflag, oflag, cflag, lflag, ispeed, ospeed, cc = \ - termios.tcgetattr(self._fd) - iflag = termios.IGNBRK | termios.IGNPAR - oflag = 0 - cflag |= termios.CLOCAL | termios.CREAD | termios.CS8 - lflag = 0 - ispeed = ospeed = getattr(termios, "B"+str(baud)) - cc[termios.VMIN] = 1 - cc[termios.VTIME] = 0 - termios.tcsetattr(self._fd, termios.TCSANOW, [ - iflag, oflag, cflag, lflag, ispeed, ospeed, cc]) - termios.tcdrain(self._fd) - termios.tcflush(self._fd, termios.TCOFLUSH) - termios.tcflush(self._fd, termios.TCIFLUSH) + def __init__(self, dev="/dev/ttyUSB1", baud=115200): + self._fd = os.open(dev, os.O_RDWR | os.O_NOCTTY) + self.port = os.fdopen(self._fd, "r+b", buffering=0) + iflag, oflag, cflag, lflag, ispeed, ospeed, cc = \ + termios.tcgetattr(self._fd) + iflag = termios.IGNBRK | termios.IGNPAR + oflag = 0 + cflag |= termios.CLOCAL | termios.CREAD | termios.CS8 + lflag = 0 + ispeed = ospeed = getattr(termios, "B"+str(baud)) + cc[termios.VMIN] = 1 + cc[termios.VTIME] = 0 + termios.tcsetattr(self._fd, termios.TCSANOW, [ + iflag, oflag, cflag, lflag, ispeed, ospeed, cc]) + termios.tcdrain(self._fd) + termios.tcflush(self._fd, termios.TCOFLUSH) + termios.tcflush(self._fd, termios.TCIFLUSH) - def close(self): - self.port.close() + def close(self): + self.port.close() - def __enter__(self): - return self + def __enter__(self): + return self - def __exit__(self, type, value, traceback): - self.close() + def __exit__(self, type, value, traceback): + self.close() - def get_runtime_env(self): - _write_exactly(self.port, struct.pack(">lb", 0x5a5a5a5a, _MsgType.REQUEST_IDENT.value)) - # FIXME: when loading immediately after a board reset, we erroneously get some zeros back. - # Ignore them with a warning for now. - spurious_zero_count = 0 - while True: - (reply, ) = struct.unpack("b", _read_exactly(self.port, 1)) - if reply == 0: - spurious_zero_count += 1 - else: - break - if spurious_zero_count: - print("Warning: received {} spurious zeros".format(spurious_zero_count)) - runtime_id = chr(reply) - for i in range(3): - (reply, ) = struct.unpack("b", _read_exactly(self.port, 1)) - runtime_id += chr(reply) - if runtime_id != "AROR": - raise UnsupportedDevice("Unsupported runtime ID: "+runtime_id) - (ref_period, ) = struct.unpack(">l", _read_exactly(self.port, 4)) - return Environment(ref_period*units.ps) + def get_runtime_env(self): + _write_exactly(self.port, struct.pack( + ">lb", 0x5a5a5a5a, _MsgType.REQUEST_IDENT.value)) + # FIXME: when loading immediately after a board reset, + # we erroneously get some zeros back. + # Ignore them with a warning for now. + spurious_zero_count = 0 + while True: + (reply, ) = struct.unpack("b", _read_exactly(self.port, 1)) + if reply == 0: + spurious_zero_count += 1 + else: + break + if spurious_zero_count: + print("Warning: received {} spurious zeros" + .format(spurious_zero_count)) + runtime_id = chr(reply) + for i in range(3): + (reply, ) = struct.unpack("b", _read_exactly(self.port, 1)) + runtime_id += chr(reply) + if runtime_id != "AROR": + raise UnsupportedDevice("Unsupported runtime ID: "+runtime_id) + (ref_period, ) = struct.unpack(">l", _read_exactly(self.port, 4)) + return Environment(ref_period*units.ps) - def run(self, kcode): - _write_exactly(self.port, struct.pack(">lblL", - 0x5a5a5a5a, _MsgType.LOAD_KERNEL.value, len(kcode), zlib.crc32(kcode))) - _write_exactly(self.port, kcode) - (reply, ) = struct.unpack("b", _read_exactly(self.port, 1)) - if reply != 0x4f: - raise IOError("Incorrect reply from device: "+hex(reply)) + def run(self, kcode): + _write_exactly(self.port, struct.pack( + ">lblL", + 0x5a5a5a5a, _MsgType.LOAD_KERNEL.value, + len(kcode), zlib.crc32(kcode))) + _write_exactly(self.port, kcode) + (reply, ) = struct.unpack("b", _read_exactly(self.port, 1)) + if reply != 0x4f: + raise IOError("Incorrect reply from device: "+hex(reply)) - def _wait_sync(self): - recognized = 0 - while recognized < 4: - (c, ) = struct.unpack("b", _read_exactly(self.port, 1)) - if c == 0x5a: - recognized += 1 - else: - recognized = 0 + def _wait_sync(self): + recognized = 0 + while recognized < 4: + (c, ) = struct.unpack("b", _read_exactly(self.port, 1)) + if c == 0x5a: + recognized += 1 + else: + recognized = 0 - def serve(self, rpc_map): - while True: - self._wait_sync() - msg = _MsgType(*struct.unpack("b", _read_exactly(self.port, 1))) - if msg == _MsgType.KERNEL_FINISHED: - return - elif msg == _MsgType.RPC_REQUEST: - rpc_num, n_args = struct.unpack(">hb", _read_exactly(self.port, 3)) - args = [] - for i in range(n_args): - args.append(*struct.unpack(">l", _read_exactly(self.port, 4))) - r = rpc_map[rpc_num](*args) - if r is None: - r = 0 - _write_exactly(self.port, struct.pack(">l", r)) + def serve(self, rpc_map): + while True: + self._wait_sync() + msg = _MsgType(*struct.unpack("b", _read_exactly(self.port, 1))) + if msg == _MsgType.KERNEL_FINISHED: + return + elif msg == _MsgType.RPC_REQUEST: + rpc_num, n_args = struct.unpack(">hb", + _read_exactly(self.port, 3)) + args = [] + for i in range(n_args): + args.append(*struct.unpack(">l", + _read_exactly(self.port, 4))) + r = rpc_map[rpc_num](*args) + if r is None: + r = 0 + _write_exactly(self.port, struct.pack(">l", r)) diff --git a/artiq/devices/dds_core.py b/artiq/devices/dds_core.py index 8c0c66180..b331b73e0 100644 --- a/artiq/devices/dds_core.py +++ b/artiq/devices/dds_core.py @@ -1,20 +1,22 @@ from artiq.language.core import * from artiq.language.units import * + class DDS(AutoContext): - parameters = "dds_sysclk reg_channel rtio_channel" + parameters = "dds_sysclk reg_channel rtio_channel" - def build(self): - self._previous_frequency = 0*MHz + def build(self): + self._previous_frequency = 0*MHz - kernel_attr = "_previous_frequency" + kernel_attr = "_previous_frequency" - @kernel - def pulse(self, frequency, duration): - if self._previous_frequency != frequency: - syscall("rtio_sync", self.rtio_channel) # wait until output is off - syscall("dds_program", self.reg_channel, int(2**32*frequency/self.dds_sysclk)) - self._previous_frequency = frequency - syscall("rtio_set", now(), self.rtio_channel, 1) - delay(duration) - syscall("rtio_set", now(), self.rtio_channel, 0) + @kernel + def pulse(self, frequency, duration): + if self._previous_frequency != frequency: + syscall("rtio_sync", self.rtio_channel) # wait until output is off + syscall("dds_program", self.reg_channel, + int(2**32*frequency/self.dds_sysclk)) + self._previous_frequency = frequency + syscall("rtio_set", now(), self.rtio_channel, 1) + delay(duration) + syscall("rtio_set", now(), self.rtio_channel, 0) diff --git a/artiq/devices/gpio_core.py b/artiq/devices/gpio_core.py index 0574cefc7..d6622b7a7 100644 --- a/artiq/devices/gpio_core.py +++ b/artiq/devices/gpio_core.py @@ -1,8 +1,9 @@ from artiq.language.core import * -class GPIOOut(AutoContext): - parameters = "channel" - @kernel - def set(self, level): - syscall("gpio_set", self.channel, level) +class GPIOOut(AutoContext): + parameters = "channel" + + @kernel + def set(self, level): + syscall("gpio_set", self.channel, level) diff --git a/artiq/devices/runtime.py b/artiq/devices/runtime.py index a134aadd7..5c345236b 100644 --- a/artiq/devices/runtime.py +++ b/artiq/devices/runtime.py @@ -3,70 +3,77 @@ from llvm import target as lt from artiq.compiler import ir_values + lt.initialize_all() _syscalls = { - "rpc": "i+:i", - "gpio_set": "ii:n", - "rtio_set": "Iii:n", - "rtio_sync": "i:n", - "dds_program": "ii:n", + "rpc": "i+:i", + "gpio_set": "ii:n", + "rtio_set": "Iii:n", + "rtio_sync": "i:n", + "dds_program": "ii:n", } _chr_to_type = { - "n": lambda: lc.Type.void(), - "i": lambda: lc.Type.int(32), - "I": lambda: lc.Type.int(64) + "n": lambda: lc.Type.void(), + "i": lambda: lc.Type.int(32), + "I": lambda: lc.Type.int(64) } _chr_to_value = { - "n": lambda: ir_values.VNone(), - "i": lambda: ir_values.VInt(), - "I": lambda: ir_values.VInt(64) + "n": lambda: ir_values.VNone(), + "i": lambda: ir_values.VInt(), + "I": lambda: ir_values.VInt(64) } -def _str_to_functype(s): - assert(s[-2] == ":") - type_ret = _chr_to_type[s[-1]]() - var_arg_fixcount = None - type_args = [] - for n, c in enumerate(s[:-2]): - if c == "+": - type_args.append(lc.Type.int()) - var_arg_fixcount = n - else: - type_args.append(_chr_to_type[c]()) - return var_arg_fixcount, lc.Type.function(type_ret, type_args, var_arg=var_arg_fixcount is not None) +def _str_to_functype(s): + assert(s[-2] == ":") + type_ret = _chr_to_type[s[-1]]() + + var_arg_fixcount = None + type_args = [] + for n, c in enumerate(s[:-2]): + if c == "+": + type_args.append(lc.Type.int()) + var_arg_fixcount = n + else: + type_args.append(_chr_to_type[c]()) + return (var_arg_fixcount, + lc.Type.function(type_ret, type_args, + var_arg=var_arg_fixcount is not None)) + class LinkInterface: - def init_module(self, module): - self.module = module - self.var_arg_fixcount = dict() - for func_name, func_type_str in _syscalls.items(): - var_arg_fixcount, func_type = _str_to_functype(func_type_str) - if var_arg_fixcount is not None: - self.var_arg_fixcount[func_name] = var_arg_fixcount - self.module.add_function(func_type, "__syscall_"+func_name) + def init_module(self, module): + self.module = module + self.var_arg_fixcount = dict() + for func_name, func_type_str in _syscalls.items(): + var_arg_fixcount, func_type = _str_to_functype(func_type_str) + if var_arg_fixcount is not None: + self.var_arg_fixcount[func_name] = var_arg_fixcount + self.module.add_function(func_type, "__syscall_"+func_name) + + def syscall(self, syscall_name, args, builder): + r = _chr_to_value[_syscalls[syscall_name][-1]]() + if builder is not None: + args = [arg.get_ssa_value(builder) for arg in args] + if syscall_name in self.var_arg_fixcount: + fixcount = self.var_arg_fixcount[syscall_name] + args = args[:fixcount] \ + + [lc.Constant.int(lc.Type.int(), len(args) - fixcount)] \ + + args[fixcount:] + llvm_function = self.module.get_function_named( + "__syscall_" + syscall_name) + r.set_ssa_value(builder, builder.call(llvm_function, args)) + return r - def syscall(self, syscall_name, args, builder): - r = _chr_to_value[_syscalls[syscall_name][-1]]() - if builder is not None: - args = [arg.get_ssa_value(builder) for arg in args] - if syscall_name in self.var_arg_fixcount: - fixcount = self.var_arg_fixcount[syscall_name] - args = args[:fixcount] \ - + [lc.Constant.int(lc.Type.int(), len(args) - fixcount)] \ - + args[fixcount:] - llvm_function = self.module.get_function_named("__syscall_"+syscall_name) - r.set_ssa_value(builder, builder.call(llvm_function, args)) - return r class Environment(LinkInterface): - def __init__(self, ref_period): - self.ref_period = ref_period - self.initial_time = 2000 + def __init__(self, ref_period): + self.ref_period = ref_period + self.initial_time = 2000 - def emit_object(self): - tm = lt.TargetMachine.new(triple="or1k", cpu="generic") - return tm.emit_object(self.module) + def emit_object(self): + tm = lt.TargetMachine.new(triple="or1k", cpu="generic") + return tm.emit_object(self.module) diff --git a/artiq/devices/ttl_core.py b/artiq/devices/ttl_core.py index 49e697b08..4320e6a9c 100644 --- a/artiq/devices/ttl_core.py +++ b/artiq/devices/ttl_core.py @@ -1,10 +1,11 @@ from artiq.language.core import * -class TTLOut(AutoContext): - parameters = "channel" - @kernel - def pulse(self, duration): - syscall("rtio_set", now(), self.channel, 1) - delay(duration) - syscall("rtio_set", now(), self.channel, 0) +class TTLOut(AutoContext): + parameters = "channel" + + @kernel + def pulse(self, duration): + syscall("rtio_set", now(), self.channel, 1) + delay(duration) + syscall("rtio_set", now(), self.channel, 0) diff --git a/artiq/language/core.py b/artiq/language/core.py index 98b88e475..ac2946036 100644 --- a/artiq/language/core.py +++ b/artiq/language/core.py @@ -3,150 +3,169 @@ from fractions import Fraction from artiq.language import units + class int64(int): - pass + pass def _make_int64_op_method(int_method): - def method(self, *args): - r = int_method(self, *args) - if isinstance(r, int): - r = int64(r) - return r - return method + def method(self, *args): + r = int_method(self, *args) + if isinstance(r, int): + r = int64(r) + return r + return method -for _op_name in ( - "neg", "pos", "abs", "invert", "round", - "add", "radd", "sub", "rsub", "mul", "rmul", "pow", "rpow", - "lshift", "rlshift", "rshift", "rrshift", - "and", "rand", "xor", "rxor", "or", "ror", - "floordiv", "rfloordiv", "mod", "rmod"): - method_name = "__" + _op_name + "__" - orig_method = getattr(int, method_name) - setattr(int64, method_name, _make_int64_op_method(orig_method)) +for _op_name in ("neg", "pos", "abs", "invert", "round", + "add", "radd", "sub", "rsub", "mul", "rmul", "pow", "rpow", + "lshift", "rlshift", "rshift", "rrshift", + "and", "rand", "xor", "rxor", "or", "ror", + "floordiv", "rfloordiv", "mod", "rmod"): + method_name = "__" + _op_name + "__" + orig_method = getattr(int, method_name) + setattr(int64, method_name, _make_int64_op_method(orig_method)) + +for _op_name in ("add", "sub", "mul", "floordiv", "mod", + "pow", "lshift", "rshift", "lshift", + "and", "xor", "or"): + op_method = getattr(int, "__" + _op_name + "__") + setattr(int64, "__i" + _op_name + "__", _make_int64_op_method(op_method)) -for _op_name in ( - "add", "sub", "mul", "floordiv", "mod", - "pow", "lshift", "rshift", "lshift", - "and", "xor", "or"): - op_method = getattr(int, "__" + _op_name + "__") - setattr(int64, "__i" + _op_name + "__", _make_int64_op_method(op_method)) def round64(x): - return int64(round(x)) + return int64(round(x)) + def _make_kernel_ro(value): - return isinstance(value, (bool, int, int64, float, Fraction, units.Quantity)) + return isinstance( + value, (bool, int, int64, float, Fraction, units.Quantity)) + class AutoContext: - parameters = "" - implicit_core = True + parameters = "" + implicit_core = True - def __init__(self, mvs=None, **kwargs): - kernel_attr_ro = [] + def __init__(self, mvs=None, **kwargs): + kernel_attr_ro = [] - self.mvs = mvs - for k, v in kwargs.items(): - setattr(self, k, v) - if _make_kernel_ro(v): - kernel_attr_ro.append(k) + self.mvs = mvs + for k, v in kwargs.items(): + setattr(self, k, v) + if _make_kernel_ro(v): + kernel_attr_ro.append(k) - parameters = self.parameters.split() - if self.implicit_core: - parameters.append("core") - for parameter in parameters: - try: - value = getattr(self, parameter) - except AttributeError: - value = self.mvs.get_missing_value(parameter) - setattr(self, parameter, value) - if _make_kernel_ro(value): - kernel_attr_ro.append(parameter) - - self.kernel_attr_ro = " ".join(kernel_attr_ro) + parameters = self.parameters.split() + if self.implicit_core: + parameters.append("core") + for parameter in parameters: + try: + value = getattr(self, parameter) + except AttributeError: + value = self.mvs.get_missing_value(parameter) + setattr(self, parameter, value) + if _make_kernel_ro(value): + kernel_attr_ro.append(parameter) - self.build() + self.kernel_attr_ro = " ".join(kernel_attr_ro) - def get_missing_value(self, parameter): - try: - return getattr(self, parameter) - except AttributeError: - return self.mvs.get_missing_value(parameter) + self.build() + + def get_missing_value(self, parameter): + try: + return getattr(self, parameter) + except AttributeError: + return self.mvs.get_missing_value(parameter) + + def build(self): + """ Overload this function to add sub-experiments""" + pass - def build(self): - """ Overload this function to add sub-experiments""" - pass KernelFunctionInfo = namedtuple("KernelFunctionInfo", "core_name k_function") + def kernel(arg): - if isinstance(arg, str): - def real_decorator(k_function): - def run_on_core(exp, *k_args, **k_kwargs): - getattr(exp, arg).run(k_function, ((exp,) + k_args), k_kwargs) - run_on_core.k_function_info = KernelFunctionInfo(core_name=arg, k_function=k_function) - return run_on_core - return real_decorator - else: - def run_on_core(exp, *k_args, **k_kwargs): - exp.core.run(arg, ((exp,) + k_args), k_kwargs) - run_on_core.k_function_info = KernelFunctionInfo(core_name="core", k_function=arg) - return run_on_core + if isinstance(arg, str): + def real_decorator(k_function): + def run_on_core(exp, *k_args, **k_kwargs): + getattr(exp, arg).run(k_function, ((exp,) + k_args), k_kwargs) + run_on_core.k_function_info = KernelFunctionInfo( + core_name=arg, k_function=k_function) + return run_on_core + return real_decorator + else: + def run_on_core(exp, *k_args, **k_kwargs): + exp.core.run(arg, ((exp,) + k_args), k_kwargs) + run_on_core.k_function_info = KernelFunctionInfo( + core_name="core", k_function=arg) + return run_on_core + class _DummyTimeManager: - def _not_implemented(self, *args, **kwargs): - raise NotImplementedError("Attempted to interpret kernel without a time manager") + def _not_implemented(self, *args, **kwargs): + raise NotImplementedError( + "Attempted to interpret kernel without a time manager") - enter_sequential = _not_implemented - enter_parallel = _not_implemented - exit = _not_implemented - take_time = _not_implemented - get_time = _not_implemented - set_time = _not_implemented + enter_sequential = _not_implemented + enter_parallel = _not_implemented + exit = _not_implemented + take_time = _not_implemented + get_time = _not_implemented + set_time = _not_implemented _time_manager = _DummyTimeManager() + def set_time_manager(time_manager): - global _time_manager - _time_manager = time_manager + global _time_manager + _time_manager = time_manager + class _DummySyscallManager: - def do(self, *args): - raise NotImplementedError("Attempted to interpret kernel without a syscall manager") + def do(self, *args): + raise NotImplementedError( + "Attempted to interpret kernel without a syscall manager") _syscall_manager = _DummySyscallManager() + def set_syscall_manager(syscall_manager): - global _syscall_manager - _syscall_manager = syscall_manager + global _syscall_manager + _syscall_manager = syscall_manager # global namespace for kernels kernel_globals = "sequential", "parallel", "delay", "now", "at", "syscall" -class _Sequential: - def __enter__(self): - _time_manager.enter_sequential() - def __exit__(self, type, value, traceback): - _time_manager.exit() +class _Sequential: + def __enter__(self): + _time_manager.enter_sequential() + + def __exit__(self, type, value, traceback): + _time_manager.exit() sequential = _Sequential() -class _Parallel: - def __enter__(self): - _time_manager.enter_parallel() - def __exit__(self, type, value, traceback): - _time_manager.exit() +class _Parallel: + def __enter__(self): + _time_manager.enter_parallel() + + def __exit__(self, type, value, traceback): + _time_manager.exit() parallel = _Parallel() + def delay(duration): - _time_manager.take_time(duration) + _time_manager.take_time(duration) + def now(): - return _time_manager.get_time() + return _time_manager.get_time() + def at(time): - _time_manager.set_time(time) + _time_manager.set_time(time) + def syscall(*args): - return _syscall_manager.do(*args) + return _syscall_manager.do(*args) diff --git a/artiq/language/units.py b/artiq/language/units.py index a56ab7279..07be3a73f 100644 --- a/artiq/language/units.py +++ b/artiq/language/units.py @@ -1,122 +1,139 @@ from collections import namedtuple from fractions import Fraction + _prefixes_str = "pnum_kMG" _smallest_prefix = Fraction(1, 10**12) Unit = namedtuple("Unit", "name") + class DimensionError(Exception): - pass + pass + class Quantity: - def __init__(self, amount, unit): - self.amount = amount - self.unit = unit + def __init__(self, amount, unit): + self.amount = amount + self.unit = unit - def __repr__(self): - r_amount = self.amount - if isinstance(r_amount, int) or isinstance(r_amount, Fraction): - r_prefix = 0 - r_amount = r_amount/_smallest_prefix - if r_amount: - numerator = r_amount.numerator - while numerator % 1000 == 0 and r_prefix < len(_prefixes_str): - numerator /= 1000 - r_amount /= 1000 - r_prefix += 1 - prefix_str = _prefixes_str[r_prefix] - if prefix_str == "_": - prefix_str = "" - return str(r_amount) + " " + prefix_str + self.unit.name - else: - return str(r_amount) + " " + self.unit.name + def __repr__(self): + r_amount = self.amount + if isinstance(r_amount, int) or isinstance(r_amount, Fraction): + r_prefix = 0 + r_amount = r_amount/_smallest_prefix + if r_amount: + numerator = r_amount.numerator + while numerator % 1000 == 0 and r_prefix < len(_prefixes_str): + numerator /= 1000 + r_amount /= 1000 + r_prefix += 1 + prefix_str = _prefixes_str[r_prefix] + if prefix_str == "_": + prefix_str = "" + return str(r_amount) + " " + prefix_str + self.unit.name + else: + return str(r_amount) + " " + self.unit.name - def __mul__(self, other): - if isinstance(other, Quantity): - return NotImplemented - return Quantity(self.amount*other, self.unit) - def __rmul__(self, other): - if isinstance(other, Quantity): - return NotImplemented - return Quantity(other*self.amount, self.unit) - def __truediv__(self, other): - if isinstance(other, Quantity): - if other.unit == self.unit: - return self.amount/other.amount - else: - return NotImplemented - else: - return Quantity(self.amount/other, self.unit) - def __floordiv__(self, other): - if isinstance(other, Quantity): - if other.unit == self.unit: - return self.amount//other.amount - else: - return NotImplemented - else: - return Quantity(self.amount//other, self.unit) + # mul/div + def __mul__(self, other): + if isinstance(other, Quantity): + return NotImplemented + return Quantity(self.amount*other, self.unit) - def __neg__(self): - return Quantity(-self.amount, self.unit) + def __rmul__(self, other): + if isinstance(other, Quantity): + return NotImplemented + return Quantity(other*self.amount, self.unit) - def __add__(self, other): - if self.unit != other.unit: - raise DimensionError - return Quantity(self.amount + other.amount, self.unit) - def __radd__(self, other): - if self.unit != other.unit: - raise DimensionError - return Quantity(other.amount + self.amount, self.unit) - def __sub__(self, other): - if self.unit != other.unit: - raise DimensionError - return Quantity(self.amount - other.amount, self.unit) - def __rsub__(self, other): - if self.unit != other.unit: - raise DimensionError - return Quantity(other.amount - self.amount, self.unit) + def __truediv__(self, other): + if isinstance(other, Quantity): + if other.unit == self.unit: + return self.amount/other.amount + else: + return NotImplemented + else: + return Quantity(self.amount/other, self.unit) - def __lt__(self, other): - if self.unit != other.unit: - raise DimensionError - return self.amount < other.amount - def __le__(self, other): - if self.unit != other.unit: - raise DimensionError - return self.amount <= other.amount - def __eq__(self, other): - if self.unit != other.unit: - raise DimensionError - return self.amount == other.amount - def __ne__(self, other): - if self.unit != other.unit: - raise DimensionError - return self.amount != other.amount - def __gt__(self, other): - if self.unit != other.unit: - raise DimensionError - return self.amount > other.amount - def __ge__(self, other): - if self.unit != other.unit: - raise DimensionError - return self.amount >= other.amount + def __floordiv__(self, other): + if isinstance(other, Quantity): + if other.unit == self.unit: + return self.amount//other.amount + else: + return NotImplemented + else: + return Quantity(self.amount//other, self.unit) + + # unary ops + def __neg__(self): + return Quantity(-self.amount, self.unit) + + def __pos__(self): + return Quantity(self.amount, self.unit) + + # add/sub + def __add__(self, other): + if self.unit != other.unit: + raise DimensionError + return Quantity(self.amount + other.amount, self.unit) + + def __radd__(self, other): + if self.unit != other.unit: + raise DimensionError + return Quantity(other.amount + self.amount, self.unit) + + def __sub__(self, other): + if self.unit != other.unit: + raise DimensionError + return Quantity(self.amount - other.amount, self.unit) + + def __rsub__(self, other): + if self.unit != other.unit: + raise DimensionError + return Quantity(other.amount - self.amount, self.unit) + + # comparisons + def __lt__(self, other): + if self.unit != other.unit: + raise DimensionError + return self.amount < other.amount + + def __le__(self, other): + if self.unit != other.unit: + raise DimensionError + return self.amount <= other.amount + + def __eq__(self, other): + if self.unit != other.unit: + raise DimensionError + return self.amount == other.amount + + def __ne__(self, other): + if self.unit != other.unit: + raise DimensionError + return self.amount != other.amount + + def __gt__(self, other): + if self.unit != other.unit: + raise DimensionError + return self.amount > other.amount + + def __ge__(self, other): + if self.unit != other.unit: + raise DimensionError + return self.amount >= other.amount -def check_unit(value, unit): - if not isinstance(value, Quantity) or value.unit != unit: - raise DimensionError - return value.amount def _register_unit(name, prefixes): - unit = Unit(name) - globals()[name+"_unit"] = unit - amount = _smallest_prefix - for prefix in _prefixes_str: - if prefix in prefixes: - quantity = Quantity(amount, unit) - full_name = prefix + name if prefix != "_" else name - globals()[full_name] = quantity - amount *= 1000 + unit = Unit(name) + globals()[name+"_unit"] = unit + amount = _smallest_prefix + for prefix in _prefixes_str: + if prefix in prefixes: + quantity = Quantity(amount, unit) + full_name = prefix + name if prefix != "_" else name + globals()[full_name] = quantity + amount *= 1000 _register_unit("s", "pnum_") _register_unit("Hz", "_kMG") diff --git a/artiq/sim/devices.py b/artiq/sim/devices.py index 042f98475..ec958455c 100644 --- a/artiq/sim/devices.py +++ b/artiq/sim/devices.py @@ -4,39 +4,43 @@ from artiq.language.core import AutoContext, delay from artiq.language import units from artiq.sim import time + class Core: - def run(self, k_function, k_args, k_kwargs): - return k_function(*k_args, **k_kwargs) + def run(self, k_function, k_args, k_kwargs): + return k_function(*k_args, **k_kwargs) + class Input(AutoContext): - parameters = "name" - implicit_core = False + parameters = "name" + implicit_core = False - def build(self): - self.prng = Random() + def build(self): + self.prng = Random() - def wait_edge(self): - duration = self.prng.randrange(0, 20)*units.ms - time.manager.event(("wait_edge", self.name, duration)) - delay(duration) + def wait_edge(self): + duration = self.prng.randrange(0, 20)*units.ms + time.manager.event(("wait_edge", self.name, duration)) + delay(duration) + + def count_gate(self, duration): + result = self.prng.randrange(0, 100) + time.manager.event(("count_gate", self.name, duration, result)) + delay(duration) + return result - def count_gate(self, duration): - result = self.prng.randrange(0, 100) - time.manager.event(("count_gate", self.name, duration, result)) - delay(duration) - return result class WaveOutput(AutoContext): - parameters = "name" - implicit_core = False + parameters = "name" + implicit_core = False + + def pulse(self, frequency, duration): + time.manager.event(("pulse", self.name, frequency, duration)) + delay(duration) - def pulse(self, frequency, duration): - time.manager.event(("pulse", self.name, frequency, duration)) - delay(duration) class VoltageOutput(AutoContext): - parameters = "name" - implicit_core = False + parameters = "name" + implicit_core = False - def set(self, value): - time.manager.event(("set_voltage", self.name, value)) + def set(self, value): + time.manager.event(("set_voltage", self.name, value)) diff --git a/artiq/sim/time.py b/artiq/sim/time.py index 05841ae52..3ae01a073 100644 --- a/artiq/sim/time.py +++ b/artiq/sim/time.py @@ -3,66 +3,69 @@ from operator import itemgetter from artiq.language.units import * from artiq.language import core as core_language -class SequentialTimeContext: - def __init__(self, current_time): - self.current_time = current_time - self.block_duration = 0*s - def take_time(self, amount): - self.current_time += amount - self.block_duration += amount +class SequentialTimeContext: + def __init__(self, current_time): + self.current_time = current_time + self.block_duration = 0*s + + def take_time(self, amount): + self.current_time += amount + self.block_duration += amount + class ParallelTimeContext: - def __init__(self, current_time): - self.current_time = current_time - self.block_duration = 0*s + def __init__(self, current_time): + self.current_time = current_time + self.block_duration = 0*s + + def take_time(self, amount): + if amount > self.block_duration: + self.block_duration = amount - def take_time(self, amount): - if amount > self.block_duration: - self.block_duration = amount class Manager: - def __init__(self): - self.stack = [SequentialTimeContext(0*s)] - self.timeline = [] + def __init__(self): + self.stack = [SequentialTimeContext(0*s)] + self.timeline = [] - def enter_sequential(self): - new_context = SequentialTimeContext(self.get_time()) - self.stack.append(new_context) + def enter_sequential(self): + new_context = SequentialTimeContext(self.get_time()) + self.stack.append(new_context) - def enter_parallel(self): - new_context = ParallelTimeContext(self.get_time()) - self.stack.append(new_context) + def enter_parallel(self): + new_context = ParallelTimeContext(self.get_time()) + self.stack.append(new_context) - def exit(self): - old_context = self.stack.pop() - self.take_time(old_context.block_duration) + def exit(self): + old_context = self.stack.pop() + self.take_time(old_context.block_duration) - def take_time(self, duration): - self.stack[-1].take_time(duration) + def take_time(self, duration): + self.stack[-1].take_time(duration) - def get_time(self): - return self.stack[-1].current_time + def get_time(self): + return self.stack[-1].current_time - def set_time(self, t): - dt = t - self.get_time() - if dt < 0*s: - raise ValueError("Attempted to go back in time") - self.take_time(dt) + def set_time(self, t): + dt = t - self.get_time() + if dt < 0*s: + raise ValueError("Attempted to go back in time") + self.take_time(dt) - def event(self, description): - self.timeline.append((self.get_time(), description)) + def event(self, description): + self.timeline.append((self.get_time(), description)) - def format_timeline(self): - r = "" - prev_time = 0*s - for time, description in sorted(self.timeline, key=itemgetter(0)): - r += "@{:10} (+{:10}) ".format(str(time), str(time-prev_time)) - for item in description: - r += "{:16}".format(str(item)) - r += "\n" - prev_time = time - return r + def format_timeline(self): + r = "" + prev_time = 0*s + for time, description in sorted(self.timeline, key=itemgetter(0)): + r += "@{:10} (+{:10}) ".format(str(time), str(time-prev_time)) + for item in description: + r += "{:16}".format(str(item)) + r += "\n" + prev_time = time + return r manager = Manager() core_language.set_time_manager(manager) diff --git a/examples/al_spectroscopy.py b/examples/al_spectroscopy.py index ee73d3926..73b95b12d 100644 --- a/examples/al_spectroscopy.py +++ b/examples/al_spectroscopy.py @@ -1,51 +1,54 @@ from artiq.language.units import * from artiq.language.core import * -class AluminumSpectroscopy(AutoContext): - parameters = "mains_sync laser_cooling spectroscopy spectroscopy_b state_detection pmt \ - spectroscopy_freq photon_limit_low photon_limit_high" - @kernel - def run(self): - state_0_count = 0 - for count in range(100): - self.mains_sync.wait_edge() - delay(10*us) - self.laser_cooling.pulse(100*MHz, 100*us) - delay(5*us) - with parallel: - self.spectroscopy.pulse(self.spectroscopy_freq, 100*us) - with sequential: - delay(50*us) - self.spectroscopy_b.set(200) - delay(5*us) - while True: - delay(5*us) - with parallel: - self.state_detection.pulse(100*MHz, 10*us) - photon_count = self.pmt.count_gate(10*us) - if photon_count < self.photon_limit_low or photon_count > self.photon_limit_high: - break - if photon_count < self.photon_limit_low: - state_0_count += 1 - return state_0_count +class AluminumSpectroscopy(AutoContext): + parameters = "mains_sync laser_cooling spectroscopy spectroscopy_b state_detection pmt \ + spectroscopy_freq photon_limit_low photon_limit_high" + + @kernel + def run(self): + state_0_count = 0 + for count in range(100): + self.mains_sync.wait_edge() + delay(10*us) + self.laser_cooling.pulse(100*MHz, 100*us) + delay(5*us) + with parallel: + self.spectroscopy.pulse(self.spectroscopy_freq, 100*us) + with sequential: + delay(50*us) + self.spectroscopy_b.set(200) + delay(5*us) + while True: + delay(5*us) + with parallel: + self.state_detection.pulse(100*MHz, 10*us) + photon_count = self.pmt.count_gate(10*us) + if (photon_count < self.photon_limit_low + or photon_count > self.photon_limit_high): + break + if photon_count < self.photon_limit_low: + state_0_count += 1 + return state_0_count + if __name__ == "__main__": - from artiq.sim import devices as sd - from artiq.sim import time + from artiq.sim import devices as sd + from artiq.sim import time - exp = AluminumSpectroscopy( - core=sd.Core(), - mains_sync=sd.Input(name="mains_sync"), - laser_cooling=sd.WaveOutput(name="laser_cooling"), - spectroscopy=sd.WaveOutput(name="spectroscopy"), - spectroscopy_b=sd.VoltageOutput(name="spectroscopy_b"), - state_detection=sd.WaveOutput(name="state_detection"), - pmt=sd.Input(name="pmt"), + exp = AluminumSpectroscopy( + core=sd.Core(), + mains_sync=sd.Input(name="mains_sync"), + laser_cooling=sd.WaveOutput(name="laser_cooling"), + spectroscopy=sd.WaveOutput(name="spectroscopy"), + spectroscopy_b=sd.VoltageOutput(name="spectroscopy_b"), + state_detection=sd.WaveOutput(name="state_detection"), + pmt=sd.Input(name="pmt"), - spectroscopy_freq=432*MHz, - photon_limit_low=10, - photon_limit_high=15 - ) - exp.run() - print(time.manager.format_timeline()) + spectroscopy_freq=432*MHz, + photon_limit_low=10, + photon_limit_high=15 + ) + exp.run() + print(time.manager.format_timeline()) diff --git a/examples/compiler_test.py b/examples/compiler_test.py index c247db045..f7856fa0c 100644 --- a/examples/compiler_test.py +++ b/examples/compiler_test.py @@ -1,41 +1,48 @@ from artiq.language.units import * from artiq.language.core import * + my_range = range + class CompilerTest(AutoContext): - parameters = "a b A B" + parameters = "a b A B" - def print_done(self): - print("Done!") + def print_done(self): + print("Done!") - def set_some_slowdev(self, n): - print("Slow device setting: {}".format(n)) + def set_some_slowdev(self, n): + print("Slow device setting: {}".format(n)) + + @kernel + def run(self, n, t2): + for i in my_range(n): + self.set_some_slowdev(i) + delay(100*ms) + with parallel: + with sequential: + for j in my_range(3): + self.a.pulse((j+1)*100*MHz, 20*us) + self.b.pulse(100*MHz, t2) + with sequential: + self.A.pulse(100*MHz, 10*us) + self.B.pulse(100*MHz, t2) + self.print_done() - @kernel - def run(self, n, t2): - for i in my_range(n): - self.set_some_slowdev(i) - delay(100*ms) - with parallel: - with sequential: - for j in my_range(3): - self.a.pulse((j+1)*100*MHz, 20*us) - self.b.pulse(100*MHz, t2) - with sequential: - self.A.pulse(100*MHz, 10*us) - self.B.pulse(100*MHz, t2) - self.print_done() if __name__ == "__main__": - from artiq.devices import corecom_dummy, core, dds_core + from artiq.devices import corecom_dummy, core, dds_core - coredev = core.Core(corecom_dummy.CoreCom()) - exp = CompilerTest( - core=coredev, - a=dds_core.DDS(core=coredev, dds_sysclk=1*GHz, reg_channel=0, rtio_channel=0), - b=dds_core.DDS(core=coredev, dds_sysclk=1*GHz, reg_channel=1, rtio_channel=1), - A=dds_core.DDS(core=coredev, dds_sysclk=1*GHz, reg_channel=2, rtio_channel=2), - B=dds_core.DDS(core=coredev, dds_sysclk=1*GHz, reg_channel=3, rtio_channel=3) - ) - exp.run(3, 100*us) + coredev = core.Core(corecom_dummy.CoreCom()) + exp = CompilerTest( + core=coredev, + a=dds_core.DDS(core=coredev, dds_sysclk=1*GHz, + reg_channel=0, rtio_channel=0), + b=dds_core.DDS(core=coredev, dds_sysclk=1*GHz, + reg_channel=1, rtio_channel=1), + A=dds_core.DDS(core=coredev, dds_sysclk=1*GHz, + reg_channel=2, rtio_channel=2), + B=dds_core.DDS(core=coredev, dds_sysclk=1*GHz, + reg_channel=3, rtio_channel=3) + ) + exp.run(3, 100*us) diff --git a/examples/coredev_test.py b/examples/coredev_test.py index cde57880d..f5e8cef1e 100644 --- a/examples/coredev_test.py +++ b/examples/coredev_test.py @@ -1,37 +1,39 @@ from artiq.language.core import AutoContext, kernel from artiq.devices import corecom_serial, core, gpio_core + class CompilerTest(AutoContext): - parameters = "led" + parameters = "led" - def output(self, n): - print("Received: "+str(n)) + def output(self, n): + print("Received: "+str(n)) - def get_max(self): - return int(input("Maximum: ")) + def get_max(self): + return int(input("Maximum: ")) + + @kernel + def run(self): + self.led.set(1) + x = 1 + m = self.get_max() + while x < m: + d = 2 + prime = True + while d*d <= x: + if x % d == 0: + prime = False + d += 1 + if prime: + self.output(x) + x += 1 + self.led.set(0) - @kernel - def run(self): - self.led.set(1) - x = 1 - m = self.get_max() - while x < m: - d = 2 - prime = True - while d*d <= x: - if x % d == 0: - prime = False - d += 1 - if prime: - self.output(x) - x += 1 - self.led.set(0) if __name__ == "__main__": - with corecom_serial.CoreCom() as com: - coredev = core.Core(com) - exp = CompilerTest( - core=coredev, - led=gpio_core.GPIOOut(core=coredev, channel=0) - ) - exp.run() + with corecom_serial.CoreCom() as com: + coredev = core.Core(com) + exp = CompilerTest( + core=coredev, + led=gpio_core.GPIOOut(core=coredev, channel=0) + ) + exp.run() diff --git a/examples/dds_test.py b/examples/dds_test.py index 5e8fa3137..d630dabdd 100644 --- a/examples/dds_test.py +++ b/examples/dds_test.py @@ -2,36 +2,42 @@ from artiq.language.units import * from artiq.language.core import * from artiq.devices import corecom_serial, core, dds_core, gpio_core -class DDSTest(AutoContext): - parameters = "a b c d led" - @kernel - def run(self): - i = 0 - while i < 10000: - if i & 0x200: - self.led.set(1) - else: - self.led.set(0) - with parallel: - with sequential: - self.a.pulse(100*MHz + 4*i*kHz, 500*us) - self.b.pulse(120*MHz, 500*us) - with sequential: - self.c.pulse(200*MHz, 100*us) - self.d.pulse(250*MHz, 200*us) - i += 1 - self.led.set(0) +class DDSTest(AutoContext): + parameters = "a b c d led" + + @kernel + def run(self): + i = 0 + while i < 10000: + if i & 0x200: + self.led.set(1) + else: + self.led.set(0) + with parallel: + with sequential: + self.a.pulse(100*MHz + 4*i*kHz, 500*us) + self.b.pulse(120*MHz, 500*us) + with sequential: + self.c.pulse(200*MHz, 100*us) + self.d.pulse(250*MHz, 200*us) + i += 1 + self.led.set(0) + if __name__ == "__main__": - with corecom_serial.CoreCom() as com: - coredev = core.Core(com) - exp = DDSTest( - core=coredev, - a=dds_core.DDS(core=coredev, dds_sysclk=1*GHz, reg_channel=0, rtio_channel=0), - b=dds_core.DDS(core=coredev, dds_sysclk=1*GHz, reg_channel=1, rtio_channel=1), - c=dds_core.DDS(core=coredev, dds_sysclk=1*GHz, reg_channel=2, rtio_channel=2), - d=dds_core.DDS(core=coredev, dds_sysclk=1*GHz, reg_channel=3, rtio_channel=3), - led=gpio_core.GPIOOut(core=coredev, channel=1) - ) - exp.run() + with corecom_serial.CoreCom() as com: + coredev = core.Core(com) + exp = DDSTest( + core=coredev, + a=dds_core.DDS(core=coredev, dds_sysclk=1*GHz, + reg_channel=0, rtio_channel=0), + b=dds_core.DDS(core=coredev, dds_sysclk=1*GHz, + reg_channel=1, rtio_channel=1), + c=dds_core.DDS(core=coredev, dds_sysclk=1*GHz, + reg_channel=2, rtio_channel=2), + d=dds_core.DDS(core=coredev, dds_sysclk=1*GHz, + reg_channel=3, rtio_channel=3), + led=gpio_core.GPIOOut(core=coredev, channel=1) + ) + exp.run() diff --git a/examples/simple_simulation.py b/examples/simple_simulation.py index a9183302b..5d1142233 100644 --- a/examples/simple_simulation.py +++ b/examples/simple_simulation.py @@ -1,29 +1,31 @@ from artiq.language.units import * from artiq.language.core import * -class SimpleSimulation(AutoContext): - parameters = "a b c d" - @kernel - def run(self): - with parallel: - with sequential: - self.a.pulse(100*MHz, 20*us) - self.b.pulse(200*MHz, 20*us) - with sequential: - self.c.pulse(300*MHz, 10*us) - self.d.pulse(400*MHz, 20*us) +class SimpleSimulation(AutoContext): + parameters = "a b c d" + + @kernel + def run(self): + with parallel: + with sequential: + self.a.pulse(100*MHz, 20*us) + self.b.pulse(200*MHz, 20*us) + with sequential: + self.c.pulse(300*MHz, 10*us) + self.d.pulse(400*MHz, 20*us) + if __name__ == "__main__": - from artiq.sim import devices as sd - from artiq.sim import time + from artiq.sim import devices as sd + from artiq.sim import time - exp = SimpleSimulation( - core=sd.Core(), - a=sd.WaveOutput(name="a"), - b=sd.WaveOutput(name="b"), - c=sd.WaveOutput(name="c"), - d=sd.WaveOutput(name="d"), - ) - exp.run() - print(time.manager.format_timeline()) + exp = SimpleSimulation( + core=sd.Core(), + a=sd.WaveOutput(name="a"), + b=sd.WaveOutput(name="b"), + c=sd.WaveOutput(name="c"), + d=sd.WaveOutput(name="d"), + ) + exp.run() + print(time.manager.format_timeline()) diff --git a/examples/time_test.py b/examples/time_test.py index 54b66b14e..0d760abe9 100644 --- a/examples/time_test.py +++ b/examples/time_test.py @@ -2,45 +2,48 @@ from artiq.language.units import * from artiq.language.core import * from artiq.devices import corecom_serial, core + class DummyPulse(AutoContext): - parameters = "name" + parameters = "name" - def print_on(self, t, f): - print("{} ON:{:4} @{}".format(self.name, f, t)) + def print_on(self, t, f): + print("{} ON:{:4} @{}".format(self.name, f, t)) - def print_off(self, t): - print("{} OFF @{}".format(self.name, t)) + def print_off(self, t): + print("{} OFF @{}".format(self.name, t)) + + @kernel + def pulse(self, f, duration): + self.print_on(int(now()), f) + delay(duration) + self.print_off(int(now())) - @kernel - def pulse(self, f, duration): - self.print_on(int(now()), f) - delay(duration) - self.print_off(int(now())) class TimeTest(AutoContext): - parameters = "a b c d" + parameters = "a b c d" + + @kernel + def run(self): + i = 0 + while i < 3: + with parallel: + with sequential: + self.a.pulse(100+i, 20*us) + self.b.pulse(200+i, 20*us) + with sequential: + self.c.pulse(300+i, 10*us) + self.d.pulse(400+i, 20*us) + i += 1 - @kernel - def run(self): - i = 0 - while i < 3: - with parallel: - with sequential: - self.a.pulse(100+i, 20*us) - self.b.pulse(200+i, 20*us) - with sequential: - self.c.pulse(300+i, 10*us) - self.d.pulse(400+i, 20*us) - i += 1 if __name__ == "__main__": - with corecom_serial.CoreCom() as com: - coredev = core.Core(com) - exp = TimeTest( - core=coredev, - a=DummyPulse(core=coredev, name="a"), - b=DummyPulse(core=coredev, name="b"), - c=DummyPulse(core=coredev, name="c"), - d=DummyPulse(core=coredev, name="d"), - ) - exp.run() + with corecom_serial.CoreCom() as com: + coredev = core.Core(com) + exp = TimeTest( + core=coredev, + a=DummyPulse(core=coredev, name="a"), + b=DummyPulse(core=coredev, name="b"), + c=DummyPulse(core=coredev, name="c"), + d=DummyPulse(core=coredev, name="d"), + ) + exp.run() diff --git a/soc/artiqlib/ad9858/__init__.py b/soc/artiqlib/ad9858/__init__.py index 1fcb40853..217113caa 100644 --- a/soc/artiqlib/ad9858/__init__.py +++ b/soc/artiqlib/ad9858/__init__.py @@ -4,192 +4,197 @@ from migen.bus import wishbone from migen.bus.transactions import * from migen.sim.generic import run_simulation + class AD9858(Module): - """Wishbone interface to the AD9858 DDS chip. + """Wishbone interface to the AD9858 DDS chip. - Addresses 0-63 map the AD9858 registers. - Data is zero-padded. + Addresses 0-63 map the AD9858 registers. + Data is zero-padded. - Write to address 64 to pulse the FUD signal. - Address 65 is a GPIO register that controls the sel, p and reset signals. - sel is mapped to the lower bits, followed by p and reset. + Write to address 64 to pulse the FUD signal. + Address 65 is a GPIO register that controls the sel, p and reset signals. + sel is mapped to the lower bits, followed by p and reset. - Write timing: - Address is set one cycle before assertion of we_n. - we_n is asserted for one cycle, at the same time as valid data is driven. + Write timing: + Address is set one cycle before assertion of we_n. + we_n is asserted for one cycle, at the same time as valid data is driven. - Read timing: - Address is set one cycle before assertion of rd_n. - rd_n is asserted for 3 cycles. - Data is sampled 2 cycles into the assertion of rd_n. + Read timing: + Address is set one cycle before assertion of rd_n. + rd_n is asserted for 3 cycles. + Data is sampled 2 cycles into the assertion of rd_n. - Design: - All IO pads are registered. + Design: + All IO pads are registered. - LVDS driver/receiver propagation delays are 3.6+4.5 ns max - LVDS state transition delays are 20, 15 ns max - Schmitt trigger delays are 6.4ns max - Round-trip addr A setup (> RX, RD, D to Z), RD prop, D valid (< D - valid), D prop is ~15 + 10 + 20 + 10 = 55ns - """ - def __init__(self, pads, bus=None): - if bus is None: - bus = wishbone.Interface() - self.bus = bus + LVDS driver/receiver propagation delays are 3.6+4.5 ns max + LVDS state transition delays are 20, 15 ns max + Schmitt trigger delays are 6.4ns max + Round-trip addr A setup (> RX, RD, D to Z), RD prop, D valid (< D + valid), D prop is ~15 + 10 + 20 + 10 = 55ns + """ + def __init__(self, pads, bus=None): + if bus is None: + bus = wishbone.Interface() + self.bus = bus - ### + # # # - dts = TSTriple(8) - self.specials += dts.get_tristate(pads.d) - dr = Signal(8) - rx = Signal() - self.sync += [ - pads.a.eq(bus.adr), - dts.o.eq(bus.dat_w), - dr.eq(dts.i), - dts.oe.eq(~rx) - ] + dts = TSTriple(8) + self.specials += dts.get_tristate(pads.d) + dr = Signal(8) + rx = Signal() + self.sync += [ + pads.a.eq(bus.adr), + dts.o.eq(bus.dat_w), + dr.eq(dts.i), + dts.oe.eq(~rx) + ] - gpio = Signal(flen(pads.sel) + flen(pads.p) + 1) - gpio_load = Signal() - self.sync += If(gpio_load, gpio.eq(bus.dat_w)) - self.comb += [ - Cat(pads.sel, pads.p).eq(gpio), - pads.rst_n.eq(~gpio[-1]), - ] + gpio = Signal(flen(pads.sel) + flen(pads.p) + 1) + gpio_load = Signal() + self.sync += If(gpio_load, gpio.eq(bus.dat_w)) + self.comb += [ + Cat(pads.sel, pads.p).eq(gpio), + pads.rst_n.eq(~gpio[-1]), + ] - bus_r_gpio = Signal() - self.comb += If(bus_r_gpio, - bus.dat_r.eq(gpio) - ).Else( - bus.dat_r.eq(dr) - ) + bus_r_gpio = Signal() + self.comb += If(bus_r_gpio, + bus.dat_r.eq(gpio) + ).Else( + bus.dat_r.eq(dr) + ) - fud = Signal() - self.sync += pads.fud_n.eq(~fud) + fud = Signal() + self.sync += pads.fud_n.eq(~fud) - pads.wr_n.reset = 1 - pads.rd_n.reset = 1 - wr = Signal() - rd = Signal() - self.sync += pads.wr_n.eq(~wr), pads.rd_n.eq(~rd) + pads.wr_n.reset = 1 + pads.rd_n.reset = 1 + wr = Signal() + rd = Signal() + self.sync += pads.wr_n.eq(~wr), pads.rd_n.eq(~rd) - fsm = FSM("IDLE") - self.submodules += fsm + fsm = FSM("IDLE") + self.submodules += fsm + + fsm.act("IDLE", + If(bus.cyc & bus.stb, + If(bus.adr[6], + If(bus.adr[0], + NextState("GPIO") + ).Else( + NextState("FUD") + ) + ).Else( + If(bus.we, + NextState("WRITE") + ).Else( + NextState("READ") + ) + ) + ) + ) + fsm.act("WRITE", + # 3ns A setup to WR active + wr.eq(1), + NextState("WRITE0") + ) + fsm.act("WRITE0", + # 3.5ns D setup to WR inactive + # 0ns D and A hold to WR inactive + bus.ack.eq(1), + NextState("IDLE") + ) + fsm.act("READ", + # 15ns D valid to A setup + # 15ns D valid to RD active + rx.eq(1), + rd.eq(1), + NextState("READ0") + ) + fsm.act("READ0", + rx.eq(1), + rd.eq(1), + NextState("READ1") + ) + fsm.act("READ1", + rx.eq(1), + rd.eq(1), + NextState("READ2") + ) + fsm.act("READ2", + rx.eq(1), + rd.eq(1), + NextState("READ3") + ) + fsm.act("READ3", + rx.eq(1), + rd.eq(1), + NextState("READ4") + ) + fsm.act("READ4", + rx.eq(1), + NextState("READ5") + ) + fsm.act("READ5", + # 5ns D three-state to RD inactive + # 10ns A hold to RD inactive + rx.eq(1), + bus.ack.eq(1), + NextState("IDLE") + ) + fsm.act("GPIO", + bus.ack.eq(1), + bus_r_gpio.eq(1), + If(bus.we, gpio_load.eq(1)), + NextState("IDLE") + ) + fsm.act("FUD", + # 4ns FUD setup to SYNCLK + # 0ns FUD hold to SYNCLK + fud.eq(1), + bus.ack.eq(1), + NextState("IDLE") + ) - fsm.act("IDLE", - If(bus.cyc & bus.stb, - If(bus.adr[6], - If(bus.adr[0], - NextState("GPIO") - ).Else( - NextState("FUD") - ) - ).Else( - If(bus.we, - NextState("WRITE") - ).Else( - NextState("READ") - ) - ) - ) - ) - fsm.act("WRITE", - # 3ns A setup to WR active - wr.eq(1), - NextState("WRITE0") - ) - fsm.act("WRITE0", - # 3.5ns D setup to WR inactive - # 0ns D and A hold to WR inactive - bus.ack.eq(1), - NextState("IDLE") - ) - fsm.act("READ", - # 15ns D valid to A setup - # 15ns D valid to RD active - rx.eq(1), - rd.eq(1), - NextState("READ0") - ) - fsm.act("READ0", - rx.eq(1), - rd.eq(1), - NextState("READ1") - ) - fsm.act("READ1", - rx.eq(1), - rd.eq(1), - NextState("READ2") - ) - fsm.act("READ2", - rx.eq(1), - rd.eq(1), - NextState("READ3") - ) - fsm.act("READ3", - rx.eq(1), - rd.eq(1), - NextState("READ4") - ) - fsm.act("READ4", - rx.eq(1), - NextState("READ5") - ) - fsm.act("READ5", - # 5ns D three-state to RD inactive - # 10ns A hold to RD inactive - rx.eq(1), - bus.ack.eq(1), - NextState("IDLE") - ) - fsm.act("GPIO", - bus.ack.eq(1), - bus_r_gpio.eq(1), - If(bus.we, gpio_load.eq(1)), - NextState("IDLE") - ) - fsm.act("FUD", - # 4ns FUD setup to SYNCLK - # 0ns FUD hold to SYNCLK - fud.eq(1), - bus.ack.eq(1), - NextState("IDLE") - ) def _test_gen(): - # Test external bus writes - yield TWrite(4, 2) - yield TWrite(5, 3) - yield - # Test external bus reads - yield TRead(14) - yield TRead(15) - yield - # Test FUD - yield TWrite(64, 0) - yield - # Test GPIO - yield TWrite(65, 0xff) - yield + # Test external bus writes + yield TWrite(4, 2) + yield TWrite(5, 3) + yield + # Test external bus reads + yield TRead(14) + yield TRead(15) + yield + # Test FUD + yield TWrite(64, 0) + yield + # Test GPIO + yield TWrite(65, 0xff) + yield + class _TestPads: - def __init__(self): - self.a = Signal(6) - self.d = Signal(8) - self.sel = Signal(5) - self.p = Signal(2) - self.fud_n = Signal() - self.wr_n = Signal() - self.rd_n = Signal() - self.rst_n = Signal() + def __init__(self): + self.a = Signal(6) + self.d = Signal(8) + self.sel = Signal(5) + self.p = Signal(2) + self.fud_n = Signal() + self.wr_n = Signal() + self.rd_n = Signal() + self.rst_n = Signal() + class _TB(Module): - def __init__(self): - pads = _TestPads() - self.submodules.dut = AD9858(pads) - self.submodules.initiator = wishbone.Initiator(_test_gen()) - self.submodules.interconnect = wishbone.InterconnectPointToPoint(self.initiator.bus, self.dut.bus) + def __init__(self): + pads = _TestPads() + self.submodules.dut = AD9858(pads) + self.submodules.initiator = wishbone.Initiator(_test_gen()) + self.submodules.interconnect = wishbone.InterconnectPointToPoint(self.initiator.bus, self.dut.bus) + if __name__ == "__main__": - run_simulation(_TB(), vcd_name="ad9858.vcd") + run_simulation(_TB(), vcd_name="ad9858.vcd") diff --git a/soc/artiqlib/rtio/core.py b/soc/artiqlib/rtio/core.py index 79ab4b77d..5b294b757 100644 --- a/soc/artiqlib/rtio/core.py +++ b/soc/artiqlib/rtio/core.py @@ -5,179 +5,182 @@ from migen.genlib.cdc import MultiReg from artiqlib.rtio.rbus import get_fine_ts_width + class _RTIOBankO(Module): - def __init__(self, rbus, counter_width, fine_ts_width, fifo_depth, counter_init): - self.sel = Signal(max=len(rbus)) - self.timestamp = Signal(counter_width+fine_ts_width) - self.value = Signal(2) - self.writable = Signal() - self.we = Signal() - self.underflow = Signal() - self.level = Signal(bits_for(fifo_depth)) + def __init__(self, rbus, counter_width, fine_ts_width, fifo_depth, counter_init): + self.sel = Signal(max=len(rbus)) + self.timestamp = Signal(counter_width+fine_ts_width) + self.value = Signal(2) + self.writable = Signal() + self.we = Signal() + self.underflow = Signal() + self.level = Signal(bits_for(fifo_depth)) - ### + # # # - counter = Signal(counter_width, reset=counter_init) - self.sync += [ - counter.eq(counter + 1), - If(self.we & self.writable, - If(self.timestamp[fine_ts_width:] < counter + 2, self.underflow.eq(1)) - ) - ] + counter = Signal(counter_width, reset=counter_init) + self.sync += [ + counter.eq(counter + 1), + If(self.we & self.writable, + If(self.timestamp[fine_ts_width:] < counter + 2, self.underflow.eq(1)) + ) + ] - fifos = [] - for n, chif in enumerate(rbus): - fifo = SyncFIFOBuffered([ - ("timestamp", counter_width+fine_ts_width), ("value", 2)], - fifo_depth) - self.submodules += fifo - fifos.append(fifo) + fifos = [] + for n, chif in enumerate(rbus): + fifo = SyncFIFOBuffered([ + ("timestamp", counter_width+fine_ts_width), ("value", 2)], + fifo_depth) + self.submodules += fifo + fifos.append(fifo) - # FIFO write - self.comb += [ - fifo.din.timestamp.eq(self.timestamp), - fifo.din.value.eq(self.value), - fifo.we.eq(self.we & (self.sel == n)) - ] + # FIFO write + self.comb += [ + fifo.din.timestamp.eq(self.timestamp), + fifo.din.value.eq(self.value), + fifo.we.eq(self.we & (self.sel == n)) + ] - # FIFO read - self.comb += [ - chif.o_stb.eq(fifo.readable & - (fifo.dout.timestamp[fine_ts_width:] == counter)), - chif.o_value.eq(fifo.dout.value), - fifo.re.eq(chif.o_stb) - ] - if fine_ts_width: - self.comb += chif.o_fine_ts.eq(fifo.dout.timestamp[:fine_ts_width]) + # FIFO read + self.comb += [ + chif.o_stb.eq(fifo.readable & + (fifo.dout.timestamp[fine_ts_width:] == counter)), + chif.o_value.eq(fifo.dout.value), + fifo.re.eq(chif.o_stb) + ] + if fine_ts_width: + self.comb += chif.o_fine_ts.eq(fifo.dout.timestamp[:fine_ts_width]) + + selfifo = Array(fifos)[self.sel] + self.comb += self.writable.eq(selfifo.writable), self.level.eq(selfifo.level) - selfifo = Array(fifos)[self.sel] - self.comb += self.writable.eq(selfifo.writable), self.level.eq(selfifo.level) class _RTIOBankI(Module): - def __init__(self, rbus, counter_width, fine_ts_width, fifo_depth): - self.sel = Signal(max=len(rbus)) - self.timestamp = Signal(counter_width+fine_ts_width) - self.value = Signal() - self.readable = Signal() - self.re = Signal() - self.overflow = Signal() + def __init__(self, rbus, counter_width, fine_ts_width, fifo_depth): + self.sel = Signal(max=len(rbus)) + self.timestamp = Signal(counter_width+fine_ts_width) + self.value = Signal() + self.readable = Signal() + self.re = Signal() + self.overflow = Signal() - ### + ### - counter = Signal(counter_width) - self.sync += counter.eq(counter + 1) + counter = Signal(counter_width) + self.sync += counter.eq(counter + 1) - timestamps = [] - values = [] - readables = [] - overflows = [] - for n, chif in enumerate(rbus): - if hasattr(chif, "oe"): - sensitivity = Signal(2) - self.sync += If(~chif.oe & chif.o_stb, - sensitivity.eq(chif.o_value)) + timestamps = [] + values = [] + readables = [] + overflows = [] + for n, chif in enumerate(rbus): + if hasattr(chif, "oe"): + sensitivity = Signal(2) + self.sync += If(~chif.oe & chif.o_stb, + sensitivity.eq(chif.o_value)) - fifo = SyncFIFOBuffered([ - ("timestamp", counter_width+fine_ts_width), ("value", 1)], - fifo_depth) - self.submodules += fifo - - # FIFO write - if fine_ts_width: - full_ts = Cat(chif.i_fine_ts, counter) - else: - full_ts = counter - self.comb += [ - fifo.din.timestamp.eq(full_ts), - fifo.din.value.eq(chif.i_value), - fifo.we.eq(~chif.oe & chif.i_stb & - ((chif.i_value & sensitivity[0]) | (~chif.i_value & sensitivity[1]))) - ] + fifo = SyncFIFOBuffered([ + ("timestamp", counter_width+fine_ts_width), ("value", 1)], + fifo_depth) + self.submodules += fifo + + # FIFO write + if fine_ts_width: + full_ts = Cat(chif.i_fine_ts, counter) + else: + full_ts = counter + self.comb += [ + fifo.din.timestamp.eq(full_ts), + fifo.din.value.eq(chif.i_value), + fifo.we.eq(~chif.oe & chif.i_stb & + ((chif.i_value & sensitivity[0]) | (~chif.i_value & sensitivity[1]))) + ] - # FIFO read - timestamps.append(fifo.dout.timestamp) - values.append(fifo.dout.value) - readables.append(fifo.readable) - self.comb += fifo.re.eq(self.re & (self.sel == n)) - - overflow = Signal() - self.sync += If(fifo.we & ~fifo.writable, overflow.eq(1)) - overflows.append(overflow) - else: - timestamps.append(0) - values.append(0) - readables.append(0) - overflows.append(0) + # FIFO read + timestamps.append(fifo.dout.timestamp) + values.append(fifo.dout.value) + readables.append(fifo.readable) + self.comb += fifo.re.eq(self.re & (self.sel == n)) + + overflow = Signal() + self.sync += If(fifo.we & ~fifo.writable, overflow.eq(1)) + overflows.append(overflow) + else: + timestamps.append(0) + values.append(0) + readables.append(0) + overflows.append(0) + + self.comb += [ + self.timestamp.eq(Array(timestamps)[self.sel]), + self.value.eq(Array(values)[self.sel]), + self.readable.eq(Array(readables)[self.sel]), + self.overflow.eq(Array(overflows)[self.sel]) + ] - self.comb += [ - self.timestamp.eq(Array(timestamps)[self.sel]), - self.value.eq(Array(values)[self.sel]), - self.readable.eq(Array(readables)[self.sel]), - self.overflow.eq(Array(overflows)[self.sel]) - ] class RTIO(Module, AutoCSR): - def __init__(self, phy, counter_width=32, ofifo_depth=8, ififo_depth=8): - fine_ts_width = get_fine_ts_width(phy.rbus) + def __init__(self, phy, counter_width=32, ofifo_depth=8, ififo_depth=8): + fine_ts_width = get_fine_ts_width(phy.rbus) - # Submodules - self.submodules.bank_o = InsertReset(_RTIOBankO(phy.rbus, - counter_width, fine_ts_width, ofifo_depth, - phy.loopback_latency)) - self.submodules.bank_i = InsertReset(_RTIOBankI(phy.rbus, - counter_width, fine_ts_width, ofifo_depth)) + # Submodules + self.submodules.bank_o = InsertReset(_RTIOBankO(phy.rbus, + counter_width, fine_ts_width, ofifo_depth, + phy.loopback_latency)) + self.submodules.bank_i = InsertReset(_RTIOBankI(phy.rbus, + counter_width, fine_ts_width, ofifo_depth)) - # CSRs - self._r_reset = CSRStorage(reset=1) - self._r_chan_sel = CSRStorage(flen(self.bank_o.sel)) - - self._r_oe = CSR() + # CSRs + self._r_reset = CSRStorage(reset=1) + self._r_chan_sel = CSRStorage(flen(self.bank_o.sel)) + + self._r_oe = CSR() - self._r_o_timestamp = CSRStorage(counter_width+fine_ts_width) - self._r_o_value = CSRStorage(2) - self._r_o_writable = CSRStatus() - self._r_o_we = CSR() - self._r_o_underflow = CSRStatus() - self._r_o_level = CSRStatus(bits_for(ofifo_depth)) + self._r_o_timestamp = CSRStorage(counter_width+fine_ts_width) + self._r_o_value = CSRStorage(2) + self._r_o_writable = CSRStatus() + self._r_o_we = CSR() + self._r_o_underflow = CSRStatus() + self._r_o_level = CSRStatus(bits_for(ofifo_depth)) - self._r_i_timestamp = CSRStatus(counter_width+fine_ts_width) - self._r_i_value = CSRStatus() - self._r_i_readable = CSRStatus() - self._r_i_re = CSR() - self._r_i_overflow = CSRStatus() + self._r_i_timestamp = CSRStatus(counter_width+fine_ts_width) + self._r_i_value = CSRStatus() + self._r_i_readable = CSRStatus() + self._r_i_re = CSR() + self._r_i_overflow = CSRStatus() - # OE - oes = [] - for n, chif in enumerate(phy.rbus): - if hasattr(chif, "oe"): - self.sync += \ - If(self._r_oe.re & (self._r_chan_sel.storage == n), - chif.oe.eq(self._r_oe.r) - ) - oes.append(chif.oe) - else: - oes.append(1) - self.comb += self._r_oe.w.eq(Array(oes)[self._r_chan_sel.storage]) + # OE + oes = [] + for n, chif in enumerate(phy.rbus): + if hasattr(chif, "oe"): + self.sync += \ + If(self._r_oe.re & (self._r_chan_sel.storage == n), + chif.oe.eq(self._r_oe.r) + ) + oes.append(chif.oe) + else: + oes.append(1) + self.comb += self._r_oe.w.eq(Array(oes)[self._r_chan_sel.storage]) - # Output/Gate - self.comb += [ - self.bank_o.reset.eq(self._r_reset.storage), - self.bank_o.sel.eq(self._r_chan_sel.storage), - self.bank_o.timestamp.eq(self._r_o_timestamp.storage), - self.bank_o.value.eq(self._r_o_value.storage), - self._r_o_writable.status.eq(self.bank_o.writable), - self.bank_o.we.eq(self._r_o_we.re), - self._r_o_underflow.status.eq(self.bank_o.underflow), - self._r_o_level.status.eq(self.bank_o.level) - ] + # Output/Gate + self.comb += [ + self.bank_o.reset.eq(self._r_reset.storage), + self.bank_o.sel.eq(self._r_chan_sel.storage), + self.bank_o.timestamp.eq(self._r_o_timestamp.storage), + self.bank_o.value.eq(self._r_o_value.storage), + self._r_o_writable.status.eq(self.bank_o.writable), + self.bank_o.we.eq(self._r_o_we.re), + self._r_o_underflow.status.eq(self.bank_o.underflow), + self._r_o_level.status.eq(self.bank_o.level) + ] - # Input - self.comb += [ - self.bank_i.reset.eq(self._r_reset.storage), - self.bank_i.sel.eq(self._r_chan_sel.storage), - self._r_i_timestamp.status.eq(self.bank_i.timestamp), - self._r_i_value.status.eq(self.bank_i.value), - self._r_i_readable.status.eq(self.bank_i.readable), - self.bank_i.re.eq(self._r_i_re.re), - self._r_i_overflow.status.eq(self.bank_i.overflow) - ] + # Input + self.comb += [ + self.bank_i.reset.eq(self._r_reset.storage), + self.bank_i.sel.eq(self._r_chan_sel.storage), + self._r_i_timestamp.status.eq(self.bank_i.timestamp), + self._r_i_value.status.eq(self.bank_i.value), + self._r_i_readable.status.eq(self.bank_i.readable), + self.bank_i.re.eq(self._r_i_re.re), + self._r_i_overflow.status.eq(self.bank_i.overflow) + ] diff --git a/soc/artiqlib/rtio/phy.py b/soc/artiqlib/rtio/phy.py index 1cdfa96ba..31b7e5870 100644 --- a/soc/artiqlib/rtio/phy.py +++ b/soc/artiqlib/rtio/phy.py @@ -3,27 +3,28 @@ from migen.genlib.cdc import MultiReg from artiqlib.rtio.rbus import create_rbus + class SimplePHY(Module): - def __init__(self, pads, output_only_pads=set()): - self.rbus = create_rbus(0, pads, output_only_pads) - self.loopback_latency = 3 + def __init__(self, pads, output_only_pads=set()): + self.rbus = create_rbus(0, pads, output_only_pads) + self.loopback_latency = 3 - ### + # # # - for pad, chif in zip(pads, self.rbus): - o_pad = Signal() - self.sync += If(chif.o_stb, o_pad.eq(chif.o_value)) - if pad in output_only_pads: - self.comb += pad.eq(o_pad) - else: - ts = TSTriple() - i_pad = Signal() - self.sync += ts.oe.eq(chif.oe) - self.comb += ts.o.eq(o_pad) - self.specials += MultiReg(ts.i, i_pad), \ - ts.get_tristate(pad) + for pad, chif in zip(pads, self.rbus): + o_pad = Signal() + self.sync += If(chif.o_stb, o_pad.eq(chif.o_value)) + if pad in output_only_pads: + self.comb += pad.eq(o_pad) + else: + ts = TSTriple() + i_pad = Signal() + self.sync += ts.oe.eq(chif.oe) + self.comb += ts.o.eq(o_pad) + self.specials += MultiReg(ts.i, i_pad), \ + ts.get_tristate(pad) - i_pad_d = Signal() - self.sync += i_pad_d.eq(i_pad) - self.comb += chif.i_stb.eq(i_pad ^ i_pad_d), \ - chif.i_value.eq(i_pad) + i_pad_d = Signal() + self.sync += i_pad_d.eq(i_pad) + self.comb += chif.i_stb.eq(i_pad ^ i_pad_d), \ + chif.i_value.eq(i_pad) diff --git a/soc/artiqlib/rtio/rbus.py b/soc/artiqlib/rtio/rbus.py index 47e739b7e..2c3008199 100644 --- a/soc/artiqlib/rtio/rbus.py +++ b/soc/artiqlib/rtio/rbus.py @@ -2,27 +2,27 @@ from migen.fhdl.std import * from migen.genlib.record import Record def create_rbus(fine_ts_bits, pads, output_only_pads): - rbus = [] - for pad in pads: - layout = [ - ("o_stb", 1), - ("o_value", 2) - ] - if fine_ts_bits: - layout.append(("o_fine_ts", fine_ts_bits)) - if pad not in output_only_pads: - layout += [ - ("oe", 1), - ("i_stb", 1), - ("i_value", 1) - ] - if fine_ts_bits: - layout.append(("i_fine_ts", fine_ts_bits)) - rbus.append(Record(layout)) - return rbus + rbus = [] + for pad in pads: + layout = [ + ("o_stb", 1), + ("o_value", 2) + ] + if fine_ts_bits: + layout.append(("o_fine_ts", fine_ts_bits)) + if pad not in output_only_pads: + layout += [ + ("oe", 1), + ("i_stb", 1), + ("i_value", 1) + ] + if fine_ts_bits: + layout.append(("i_fine_ts", fine_ts_bits)) + rbus.append(Record(layout)) + return rbus def get_fine_ts_width(rbus): - if hasattr(rbus[0], "o_fine_ts"): - return flen(rbus[0].o_fine_ts) - else: - return 0 + if hasattr(rbus[0], "o_fine_ts"): + return flen(rbus[0].o_fine_ts) + else: + return 0 diff --git a/soc/runtime/corecom_serial.c b/soc/runtime/corecom_serial.c index bb39b0ba9..e15b016ae 100644 --- a/soc/runtime/corecom_serial.c +++ b/soc/runtime/corecom_serial.c @@ -6,123 +6,123 @@ #include "corecom.h" enum { - MSGTYPE_REQUEST_IDENT = 0x01, - MSGTYPE_LOAD_KERNEL = 0x02, - MSGTYPE_KERNEL_FINISHED = 0x03, - MSGTYPE_RPC_REQUEST = 0x04, + MSGTYPE_REQUEST_IDENT = 0x01, + MSGTYPE_LOAD_KERNEL = 0x02, + MSGTYPE_KERNEL_FINISHED = 0x03, + MSGTYPE_RPC_REQUEST = 0x04, }; static int receive_int(void) { - unsigned int r; - int i; + unsigned int r; + int i; - r = 0; - for(i=0;i<4;i++) { - r <<= 8; - r |= (unsigned char)uart_read(); - } - return r; + r = 0; + for(i=0;i<4;i++) { + r <<= 8; + r |= (unsigned char)uart_read(); + } + return r; } static char receive_char(void) { - return uart_read(); + return uart_read(); } static void send_int(int x) { - int i; + int i; - for(i=0;i<4;i++) { - uart_write((x & 0xff000000) >> 24); - x <<= 8; - } + for(i=0;i<4;i++) { + uart_write((x & 0xff000000) >> 24); + x <<= 8; + } } static void send_sint(short int i) { - uart_write((i >> 8) & 0xff); - uart_write(i & 0xff); + uart_write((i >> 8) & 0xff); + uart_write(i & 0xff); } static void send_char(char c) { - uart_write(c); + uart_write(c); } static void receive_sync(void) { - char c; - int recognized; + char c; + int recognized; - recognized = 0; - while(recognized < 4) { - c = uart_read(); - if(c == 0x5a) - recognized++; - else - recognized = 0; - } + recognized = 0; + while(recognized < 4) { + c = uart_read(); + if(c == 0x5a) + recognized++; + else + recognized = 0; + } } static void send_sync(void) { - send_int(0x5a5a5a5a); + send_int(0x5a5a5a5a); } int ident_and_download_kernel(void *buffer, int maxlength) { - int length; - unsigned int crc; - int i; - char msgtype; - unsigned char *_buffer = buffer; + int length; + unsigned int crc; + int i; + char msgtype; + unsigned char *_buffer = buffer; - while(1) { - receive_sync(); - msgtype = receive_char(); - if(msgtype == MSGTYPE_REQUEST_IDENT) { - send_int(0x41524f52); /* "AROR" - ARTIQ runtime on OpenRISC */ - send_int(1000000000000LL/identifier_frequency_read()); /* RTIO clock period in picoseconds */ - } else if(msgtype == MSGTYPE_LOAD_KERNEL) { - length = receive_int(); - if(length > maxlength) { - send_char(0x4c); /* Incorrect length */ - return -1; - } - crc = receive_int(); - for(i=0;i maxlength) { + send_char(0x4c); /* Incorrect length */ + return -1; + } + crc = receive_int(); + for(i=0;i> 8) & 0xff); - DDS_WRITE(DDS_FTW2, (ftw >> 16) & 0xff); - DDS_WRITE(DDS_FTW3, (ftw >> 24) & 0xff); - DDS_WRITE(DDS_FUD, 0); + DDS_WRITE(DDS_GPIO, channel); + DDS_WRITE(DDS_FTW0, ftw & 0xff); + DDS_WRITE(DDS_FTW1, (ftw >> 8) & 0xff); + DDS_WRITE(DDS_FTW2, (ftw >> 16) & 0xff); + DDS_WRITE(DDS_FTW3, (ftw >> 24) & 0xff); + DDS_WRITE(DDS_FUD, 0); } diff --git a/soc/runtime/elf_loader.c b/soc/runtime/elf_loader.c index 718bec49a..d885b3529 100644 --- a/soc/runtime/elf_loader.c +++ b/soc/runtime/elf_loader.c @@ -6,27 +6,27 @@ #define EI_NIDENT 16 struct elf32_ehdr { - unsigned char ident[EI_NIDENT]; /* ident bytes */ - unsigned short type; /* file type */ - unsigned short machine; /* target machine */ - unsigned int version; /* file version */ - unsigned int entry; /* start address */ - unsigned int phoff; /* phdr file offset */ - unsigned int shoff; /* shdr file offset */ - unsigned int flags; /* file flags */ - unsigned short ehsize; /* sizeof ehdr */ - unsigned short phentsize; /* sizeof phdr */ - unsigned short phnum; /* number phdrs */ - unsigned short shentsize; /* sizeof shdr */ - unsigned short shnum; /* number shdrs */ - unsigned short shstrndx; /* shdr string index */ + unsigned char ident[EI_NIDENT]; /* ident bytes */ + unsigned short type; /* file type */ + unsigned short machine; /* target machine */ + unsigned int version; /* file version */ + unsigned int entry; /* start address */ + unsigned int phoff; /* phdr file offset */ + unsigned int shoff; /* shdr file offset */ + unsigned int flags; /* file flags */ + unsigned short ehsize; /* sizeof ehdr */ + unsigned short phentsize; /* sizeof phdr */ + unsigned short phnum; /* number phdrs */ + unsigned short shentsize; /* sizeof shdr */ + unsigned short shnum; /* number shdrs */ + unsigned short shstrndx; /* shdr string index */ } __attribute__((packed)); static const unsigned char elf_magic_header[] = { - 0x7f, 0x45, 0x4c, 0x46, /* 0x7f, 'E', 'L', 'F' */ - 0x01, /* Only 32-bit objects. */ - 0x02, /* Only big-endian. */ - 0x01, /* Only ELF version 1. */ + 0x7f, 0x45, 0x4c, 0x46, /* 0x7f, 'E', 'L', 'F' */ + 0x01, /* Only 32-bit objects. */ + 0x02, /* Only big-endian. */ + 0x01, /* Only ELF version 1. */ }; #define ET_NONE 0 /* Unknown type. */ @@ -38,26 +38,26 @@ static const unsigned char elf_magic_header[] = { #define EM_OR1K 0x005c struct elf32_shdr { - unsigned int name; /* section name */ - unsigned int type; /* SHT_... */ - unsigned int flags; /* SHF_... */ - unsigned int addr; /* virtual address */ - unsigned int offset; /* file offset */ - unsigned int size; /* section size */ - unsigned int link; /* misc info */ - unsigned int info; /* misc info */ - unsigned int addralign; /* memory alignment */ - unsigned int entsize; /* entry size if table */ + unsigned int name; /* section name */ + unsigned int type; /* SHT_... */ + unsigned int flags; /* SHF_... */ + unsigned int addr; /* virtual address */ + unsigned int offset; /* file offset */ + unsigned int size; /* section size */ + unsigned int link; /* misc info */ + unsigned int info; /* misc info */ + unsigned int addralign; /* memory alignment */ + unsigned int entsize; /* entry size if table */ } __attribute__((packed)); struct elf32_name { - char name[12]; + char name[12]; } __attribute__((packed)); struct elf32_rela { - unsigned int offset; /* Location to be relocated. */ - unsigned int info; /* Relocation type and symbol index. */ - int addend; /* Addend. */ + unsigned int offset; /* Location to be relocated. */ + unsigned int info; /* Relocation type and symbol index. */ + int addend; /* Addend. */ } __attribute__((packed)); #define ELF32_R_SYM(info) ((info) >> 8) @@ -66,151 +66,151 @@ struct elf32_rela { #define R_OR1K_INSN_REL_26 6 struct elf32_sym { - unsigned int name; /* String table index of name. */ - unsigned int value; /* Symbol value. */ - unsigned int size; /* Size of associated object. */ - unsigned char info; /* Type and binding information. */ - unsigned char other; /* Reserved (not used). */ - unsigned short shndx; /* Section index of symbol. */ + unsigned int name; /* String table index of name. */ + unsigned int value; /* Symbol value. */ + unsigned int size; /* Size of associated object. */ + unsigned char info; /* Type and binding information. */ + unsigned char other; /* Reserved (not used). */ + unsigned short shndx; /* Section index of symbol. */ } __attribute__((packed)); #define SANITIZE_OFFSET_SIZE(offset, size) \ - if(offset > 0x10000000) { \ - printf("Incorrect offset in ELF data"); \ - return 0; \ - } \ - if((offset + size) > elf_length) { \ - printf("Attempted to access past the end of ELF data"); \ - return 0; \ - } + if(offset > 0x10000000) { \ + printf("Incorrect offset in ELF data"); \ + return 0; \ + } \ + if((offset + size) > elf_length) { \ + printf("Attempted to access past the end of ELF data"); \ + return 0; \ + } #define GET_POINTER_SAFE(target, target_type, offset) \ - SANITIZE_OFFSET_SIZE(offset, sizeof(target_type)); \ - target = (target_type *)((char *)elf_data + offset) + SANITIZE_OFFSET_SIZE(offset, sizeof(target_type)); \ + target = (target_type *)((char *)elf_data + offset) void *find_symbol(const struct symbol *symbols, const char *name) { - int i; + int i; - i = 0; - while((symbols[i].name != NULL) && (strcmp(symbols[i].name, name) != 0)) - i++; - return symbols[i].target; + i = 0; + while((symbols[i].name != NULL) && (strcmp(symbols[i].name, name) != 0)) + i++; + return symbols[i].target; } static int fixup(void *dest, int dest_length, struct elf32_rela *rela, void *target) { - int type, offset; - unsigned int *_dest = dest; - unsigned int *_target = target; + int type, offset; + unsigned int *_dest = dest; + unsigned int *_target = target; - type = ELF32_R_TYPE(rela->info); - offset = rela->offset/4; - if(type == R_OR1K_INSN_REL_26) { - int val; + type = ELF32_R_TYPE(rela->info); + offset = rela->offset/4; + if(type == R_OR1K_INSN_REL_26) { + int val; - val = _target - (_dest + offset); - _dest[offset] = (_dest[offset] & 0xfc000000) | (val & 0x03ffffff); - } else - printf("Unsupported relocation type: %d\n", type); - return 1; + val = _target - (_dest + offset); + _dest[offset] = (_dest[offset] & 0xfc000000) | (val & 0x03ffffff); + } else + printf("Unsupported relocation type: %d\n", type); + return 1; } int load_elf(symbol_resolver resolver, void *elf_data, int elf_length, void *dest, int dest_length) { - struct elf32_ehdr *ehdr; - struct elf32_shdr *strtable; - unsigned int shdrptr; - int i; + struct elf32_ehdr *ehdr; + struct elf32_shdr *strtable; + unsigned int shdrptr; + int i; - unsigned int textoff, textsize; - unsigned int textrelaoff, textrelasize; - unsigned int symtaboff, symtabsize; - unsigned int strtaboff, strtabsize; + unsigned int textoff, textsize; + unsigned int textrelaoff, textrelasize; + unsigned int symtaboff, symtabsize; + unsigned int strtaboff, strtabsize; - /* validate ELF */ - GET_POINTER_SAFE(ehdr, struct elf32_ehdr, 0); - if(memcmp(ehdr->ident, elf_magic_header, sizeof(elf_magic_header)) != 0) { - printf("Incorrect ELF header\n"); - return 0; - } - if(ehdr->type != ET_REL) { - printf("ELF is not relocatable\n"); - return 0; - } - if(ehdr->machine != EM_OR1K) { - printf("ELF is for a different machine\n"); - return 0; - } + /* validate ELF */ + GET_POINTER_SAFE(ehdr, struct elf32_ehdr, 0); + if(memcmp(ehdr->ident, elf_magic_header, sizeof(elf_magic_header)) != 0) { + printf("Incorrect ELF header\n"); + return 0; + } + if(ehdr->type != ET_REL) { + printf("ELF is not relocatable\n"); + return 0; + } + if(ehdr->machine != EM_OR1K) { + printf("ELF is for a different machine\n"); + return 0; + } - /* extract section info */ - GET_POINTER_SAFE(strtable, struct elf32_shdr, ehdr->shoff + ehdr->shentsize*ehdr->shstrndx); - textoff = textsize = 0; - textrelaoff = textrelasize = 0; - symtaboff = symtabsize = 0; - strtaboff = strtabsize = 0; - shdrptr = ehdr->shoff; - for(i=0;ishnum;i++) { - struct elf32_shdr *shdr; - struct elf32_name *name; + /* extract section info */ + GET_POINTER_SAFE(strtable, struct elf32_shdr, ehdr->shoff + ehdr->shentsize*ehdr->shstrndx); + textoff = textsize = 0; + textrelaoff = textrelasize = 0; + symtaboff = symtabsize = 0; + strtaboff = strtabsize = 0; + shdrptr = ehdr->shoff; + for(i=0;ishnum;i++) { + struct elf32_shdr *shdr; + struct elf32_name *name; - GET_POINTER_SAFE(shdr, struct elf32_shdr, shdrptr); - GET_POINTER_SAFE(name, struct elf32_name, strtable->offset + shdr->name); - - if(strncmp(name->name, ".text", 5) == 0) { - textoff = shdr->offset; - textsize = shdr->size; - } else if(strncmp(name->name, ".rela.text", 10) == 0) { - textrelaoff = shdr->offset; - textrelasize = shdr->size; - } else if(strncmp(name->name, ".symtab", 7) == 0) { - symtaboff = shdr->offset; - symtabsize = shdr->size; - } else if(strncmp(name->name, ".strtab", 7) == 0) { - strtaboff = shdr->offset; - strtabsize = shdr->size; - } + GET_POINTER_SAFE(shdr, struct elf32_shdr, shdrptr); + GET_POINTER_SAFE(name, struct elf32_name, strtable->offset + shdr->name); + + if(strncmp(name->name, ".text", 5) == 0) { + textoff = shdr->offset; + textsize = shdr->size; + } else if(strncmp(name->name, ".rela.text", 10) == 0) { + textrelaoff = shdr->offset; + textrelasize = shdr->size; + } else if(strncmp(name->name, ".symtab", 7) == 0) { + symtaboff = shdr->offset; + symtabsize = shdr->size; + } else if(strncmp(name->name, ".strtab", 7) == 0) { + strtaboff = shdr->offset; + strtabsize = shdr->size; + } - shdrptr += ehdr->shentsize; - } - SANITIZE_OFFSET_SIZE(textoff, textsize); - SANITIZE_OFFSET_SIZE(textrelaoff, textrelasize); - SANITIZE_OFFSET_SIZE(symtaboff, symtabsize); - SANITIZE_OFFSET_SIZE(strtaboff, strtabsize); + shdrptr += ehdr->shentsize; + } + SANITIZE_OFFSET_SIZE(textoff, textsize); + SANITIZE_OFFSET_SIZE(textrelaoff, textrelasize); + SANITIZE_OFFSET_SIZE(symtaboff, symtabsize); + SANITIZE_OFFSET_SIZE(strtaboff, strtabsize); - /* load .text section */ - if(textsize > dest_length) { - printf(".text section is too large\n"); - return 0; - } - memcpy(dest, (char *)elf_data + textoff, textsize); + /* load .text section */ + if(textsize > dest_length) { + printf(".text section is too large\n"); + return 0; + } + memcpy(dest, (char *)elf_data + textoff, textsize); - /* process .text relocations */ - for(i=0;iinfo)); - if(sym->name != 0) { - void *target; + GET_POINTER_SAFE(rela, struct elf32_rela, textrelaoff + i); + GET_POINTER_SAFE(sym, struct elf32_sym, symtaboff + sizeof(struct elf32_sym)*ELF32_R_SYM(rela->info)); + if(sym->name != 0) { + void *target; - name = (char *)elf_data + strtaboff + sym->name; - target = resolver(name); - if(target == NULL) { - printf("Undefined symbol: %s\n", name); - return 0; - } - if(!fixup(dest, dest_length, rela, target)) - return 0; - } else { - printf("Unsupported relocation\n"); - return 0; - } - } + name = (char *)elf_data + strtaboff + sym->name; + target = resolver(name); + if(target == NULL) { + printf("Undefined symbol: %s\n", name); + return 0; + } + if(!fixup(dest, dest_length, rela, target)) + return 0; + } else { + printf("Unsupported relocation\n"); + return 0; + } + } - return 1; + return 1; } diff --git a/soc/runtime/elf_loader.h b/soc/runtime/elf_loader.h index 5b8431ff8..958a79be1 100644 --- a/soc/runtime/elf_loader.h +++ b/soc/runtime/elf_loader.h @@ -2,8 +2,8 @@ #define __ELF_LOADER_H struct symbol { - char *name; - void *target; + char *name; + void *target; }; void *find_symbol(const struct symbol *symbols, const char *name); diff --git a/soc/runtime/gpio.c b/soc/runtime/gpio.c index 3d7948ac4..de0b64bbb 100644 --- a/soc/runtime/gpio.c +++ b/soc/runtime/gpio.c @@ -4,11 +4,11 @@ void gpio_set(int channel, int value) { - static int csr_value; + static int csr_value; - if(value) - csr_value |= 1 << channel; - else - csr_value &= ~(1 << channel); - leds_out_write(csr_value); + if(value) + csr_value |= 1 << channel; + else + csr_value &= ~(1 << channel); + leds_out_write(csr_value); } diff --git a/soc/runtime/isr.c b/soc/runtime/isr.c index c49d31d8d..f42fa0694 100644 --- a/soc/runtime/isr.c +++ b/soc/runtime/isr.c @@ -5,10 +5,10 @@ void isr(void); void isr(void) { - unsigned int irqs; - - irqs = irq_pending() & irq_getmask(); - - if(irqs & (1 << UART_INTERRUPT)) - uart_isr(); + unsigned int irqs; + + irqs = irq_pending() & irq_getmask(); + + if(irqs & (1 << UART_INTERRUPT)) + uart_isr(); } diff --git a/soc/runtime/main.c b/soc/runtime/main.c index 6aca21f55..65878a13b 100644 --- a/soc/runtime/main.c +++ b/soc/runtime/main.c @@ -13,29 +13,29 @@ typedef void (*kernel_function)(void); int main(void) { - unsigned char kbuf[256*1024]; - unsigned char kcode[256*1024]; - kernel_function k = (kernel_function)kcode; - int length; + unsigned char kbuf[256*1024]; + unsigned char kcode[256*1024]; + kernel_function k = (kernel_function)kcode; + int length; - irq_setmask(0); - irq_setie(1); - uart_init(); - - puts("ARTIQ runtime built "__DATE__" "__TIME__"\n"); + irq_setmask(0); + irq_setie(1); + uart_init(); + + puts("ARTIQ runtime built "__DATE__" "__TIME__"\n"); - while(1) { - length = ident_and_download_kernel(kbuf, sizeof(kbuf)); - if(length > 0) { - if(load_elf(resolve_symbol, kbuf, length, kcode, sizeof(kcode))) { - rtio_init(); - dds_init(); - flush_cpu_icache(); - k(); - kernel_finished(); - } - } - } + while(1) { + length = ident_and_download_kernel(kbuf, sizeof(kbuf)); + if(length > 0) { + if(load_elf(resolve_symbol, kbuf, length, kcode, sizeof(kcode))) { + rtio_init(); + dds_init(); + flush_cpu_icache(); + k(); + kernel_finished(); + } + } + } - return 0; + return 0; } diff --git a/soc/runtime/rtio.c b/soc/runtime/rtio.c index f75aa2413..ff65b5b79 100644 --- a/soc/runtime/rtio.c +++ b/soc/runtime/rtio.c @@ -4,21 +4,21 @@ void rtio_init(void) { - rtio_reset_write(1); + rtio_reset_write(1); } void rtio_set(long long int timestamp, int channel, int value) { - rtio_reset_write(0); - rtio_chan_sel_write(channel); - rtio_o_timestamp_write(timestamp); - rtio_o_value_write(value); - while(!rtio_o_writable_read()); - rtio_o_we_write(1); + rtio_reset_write(0); + rtio_chan_sel_write(channel); + rtio_o_timestamp_write(timestamp); + rtio_o_value_write(value); + while(!rtio_o_writable_read()); + rtio_o_we_write(1); } void rtio_sync(int channel) { - rtio_chan_sel_write(channel); - while(rtio_o_level_read() != 0); + rtio_chan_sel_write(channel); + while(rtio_o_level_read() != 0); } diff --git a/soc/runtime/symbols.c b/soc/runtime/symbols.c index a00fe0d5c..5a794e955 100644 --- a/soc/runtime/symbols.c +++ b/soc/runtime/symbols.c @@ -8,34 +8,34 @@ #include "symbols.h" static const struct symbol syscalls[] = { - {"rpc", rpc}, - {"gpio_set", gpio_set}, - {"rtio_set", rtio_set}, - {"rtio_sync", rtio_sync}, - {"dds_program", dds_program}, - {NULL, NULL} + {"rpc", rpc}, + {"gpio_set", gpio_set}, + {"rtio_set", rtio_set}, + {"rtio_sync", rtio_sync}, + {"dds_program", dds_program}, + {NULL, NULL} }; static long long int gcd64(long long int a, long long int b) { - long long int c; + long long int c; - while(a) { - c = a; - a = b % a; - b = c; - } - return b; + while(a) { + c = a; + a = b % a; + b = c; + } + return b; } static const struct symbol arithmetic[] = { - {"__gcd64", gcd64}, - {NULL, NULL} + {"__gcd64", gcd64}, + {NULL, NULL} }; void *resolve_symbol(const char *name) { - if(strncmp(name, "__syscall_", 10) == 0) - return find_symbol(syscalls, name + 10); - return find_symbol(arithmetic, name); + if(strncmp(name, "__syscall_", 10) == 0) + return find_symbol(syscalls, name + 10); + return find_symbol(arithmetic, name); } diff --git a/soc/targets/artiq.py b/soc/targets/artiq.py index 46cf2c6c0..c93936e91 100644 --- a/soc/targets/artiq.py +++ b/soc/targets/artiq.py @@ -7,44 +7,44 @@ from targets.ppro import BaseSoC from artiqlib import rtio, ad9858 _tester_io = [ - ("user_led", 1, Pins("B:7"), IOStandard("LVTTL")), - ("ttl", 0, Pins("C:13"), IOStandard("LVTTL")), - ("ttl", 1, Pins("C:11"), IOStandard("LVTTL")), - ("ttl", 2, Pins("C:10"), IOStandard("LVTTL")), - ("ttl", 3, Pins("C:9"), IOStandard("LVTTL")), - ("ttl_tx_en", 0, Pins("A:9"), IOStandard("LVTTL")), - ("dds", 0, - Subsignal("a", Pins("A:5 B:10 A:6 B:9 A:7 B:8")), - Subsignal("d", Pins("A:12 B:3 A:13 B:2 A:14 B:1 A:15 B:0")), - Subsignal("sel", Pins("A:2 B:14 A:1 B:15 A:0")), - Subsignal("p", Pins("A:8 B:12")), - Subsignal("fud_n", Pins("B:11")), - Subsignal("wr_n", Pins("A:4")), - Subsignal("rd_n", Pins("B:13")), - Subsignal("rst_n", Pins("A:3")), - IOStandard("LVTTL")), + ("user_led", 1, Pins("B:7"), IOStandard("LVTTL")), + ("ttl", 0, Pins("C:13"), IOStandard("LVTTL")), + ("ttl", 1, Pins("C:11"), IOStandard("LVTTL")), + ("ttl", 2, Pins("C:10"), IOStandard("LVTTL")), + ("ttl", 3, Pins("C:9"), IOStandard("LVTTL")), + ("ttl_tx_en", 0, Pins("A:9"), IOStandard("LVTTL")), + ("dds", 0, + Subsignal("a", Pins("A:5 B:10 A:6 B:9 A:7 B:8")), + Subsignal("d", Pins("A:12 B:3 A:13 B:2 A:14 B:1 A:15 B:0")), + Subsignal("sel", Pins("A:2 B:14 A:1 B:15 A:0")), + Subsignal("p", Pins("A:8 B:12")), + Subsignal("fud_n", Pins("B:11")), + Subsignal("wr_n", Pins("A:4")), + Subsignal("rd_n", Pins("B:13")), + Subsignal("rst_n", Pins("A:3")), + IOStandard("LVTTL")), ] class ARTIQMiniSoC(BaseSoC): - csr_map = { - "rtio": 10 - } - csr_map.update(BaseSoC.csr_map) + csr_map = { + "rtio": 10 + } + csr_map.update(BaseSoC.csr_map) - def __init__(self, platform, cpu_type="or1k", **kwargs): - BaseSoC.__init__(self, platform, cpu_type=cpu_type, **kwargs) - platform.add_extension(_tester_io) + def __init__(self, platform, cpu_type="or1k", **kwargs): + BaseSoC.__init__(self, platform, cpu_type=cpu_type, **kwargs) + platform.add_extension(_tester_io) - self.submodules.leds = gpio.GPIOOut(Cat(platform.request("user_led", 0), - platform.request("user_led", 1))) + self.submodules.leds = gpio.GPIOOut(Cat(platform.request("user_led", 0), + platform.request("user_led", 1))) - self.comb += platform.request("ttl_tx_en").eq(1) - rtio_pads = [platform.request("ttl", i) for i in range(4)] - self.submodules.rtiophy = rtio.phy.SimplePHY(rtio_pads, - {rtio_pads[1], rtio_pads[2], rtio_pads[3]}) - self.submodules.rtio = rtio.RTIO(self.rtiophy) + self.comb += platform.request("ttl_tx_en").eq(1) + rtio_pads = [platform.request("ttl", i) for i in range(4)] + self.submodules.rtiophy = rtio.phy.SimplePHY(rtio_pads, + {rtio_pads[1], rtio_pads[2], rtio_pads[3]}) + self.submodules.rtio = rtio.RTIO(self.rtiophy) - self.submodules.dds = ad9858.AD9858(platform.request("dds")) - self.add_wb_slave(lambda a: a[26:29] == 3, self.dds.bus) + self.submodules.dds = ad9858.AD9858(platform.request("dds")) + self.add_wb_slave(lambda a: a[26:29] == 3, self.dds.bus) default_subtarget = ARTIQMiniSoC