From 1f92e19f2bf9b0837d46d8599cac101971c372f0 Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Fri, 21 Nov 2014 15:51:20 -0800 Subject: [PATCH] transforms/fold_constants: support decimal fractions --- artiq/transforms/fold_constants.py | 16 ++++++++++------ test/full_stack.py | 17 ++++++++++++----- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/artiq/transforms/fold_constants.py b/artiq/transforms/fold_constants.py index ff9aac9ca..402fc243b 100644 --- a/artiq/transforms/fold_constants.py +++ b/artiq/transforms/fold_constants.py @@ -1,5 +1,6 @@ import ast import operator +from fractions import Fraction from artiq.transforms.tools import * from artiq.language.core import int64, round64 @@ -135,14 +136,17 @@ class _ConstantFolder(ast.NodeTransformer): "int": int, "int64": int64, "round": round, - "round64": round64 + "round64": round64, + "Fraction": Fraction } if fn in constant_ops: - try: - arg = eval_constant(node.args[0]) - except NotConstant: - return node - result = value_to_ast(constant_ops[fn](arg)) + args = [] + for arg in node.args: + try: + args.append(eval_constant(arg)) + except NotConstant: + return node + result = value_to_ast(constant_ops[fn](*args)) return ast.copy_location(result, node) else: return node diff --git a/test/full_stack.py b/test/full_stack.py index 73203b636..d8697b0b9 100644 --- a/test/full_stack.py +++ b/test/full_stack.py @@ -1,6 +1,7 @@ import unittest from operator import itemgetter import os +from fractions import Fraction from artiq import * from artiq.coredevice import comm_serial, core, runtime_exceptions, rtio @@ -40,13 +41,16 @@ class _Primes(AutoContext): self.output_list.append(x) -class _Attributes(AutoContext): +class _Misc(AutoContext): def build(self): self.input = 84 @kernel def run(self): - self.result = self.input//2 + self.half_input = self.input//2 + decimal_fraction = Fraction("1.2") + self.decimal_fraction_n = int(decimal_fraction.numerator) + self.decimal_fraction_d = int(decimal_fraction.denominator) class _PulseLogger(AutoContext): @@ -144,12 +148,15 @@ class ExecutionCase(unittest.TestCase): _run_on_host(_Primes, max=100, output_list=l_host) self.assertEqual(l_device, l_host) - def test_attributes(self): + def test_misc(self): with comm_serial.Comm() as comm: coredev = core.Core(comm) - uut = _Attributes(core=coredev) + uut = _Misc(core=coredev) uut.run() - self.assertEqual(uut.result, 42) + self.assertEqual(uut.half_input, 42) + self.assertEqual(Fraction(uut.decimal_fraction_n, + uut.decimal_fraction_d), + Fraction("1.2")) def test_pulses(self): l_device, l_host = [], []