forked from M-Labs/artiq
1
0
Fork 0

transforms/fold_constants: support decimal fractions

This commit is contained in:
Sebastien Bourdeauducq 2014-11-21 15:51:20 -08:00
parent 57cc6479c4
commit 1f92e19f2b
2 changed files with 22 additions and 11 deletions

View File

@ -1,5 +1,6 @@
import ast
import operator
from fractions import Fraction
from artiq.transforms.tools import *
from artiq.language.core import int64, round64
@ -135,14 +136,17 @@ class _ConstantFolder(ast.NodeTransformer):
"int": int,
"int64": int64,
"round": round,
"round64": round64
"round64": round64,
"Fraction": Fraction
}
if fn in constant_ops:
try:
arg = eval_constant(node.args[0])
except NotConstant:
return node
result = value_to_ast(constant_ops[fn](arg))
args = []
for arg in node.args:
try:
args.append(eval_constant(arg))
except NotConstant:
return node
result = value_to_ast(constant_ops[fn](*args))
return ast.copy_location(result, node)
else:
return node

View File

@ -1,6 +1,7 @@
import unittest
from operator import itemgetter
import os
from fractions import Fraction
from artiq import *
from artiq.coredevice import comm_serial, core, runtime_exceptions, rtio
@ -40,13 +41,16 @@ class _Primes(AutoContext):
self.output_list.append(x)
class _Attributes(AutoContext):
class _Misc(AutoContext):
def build(self):
self.input = 84
@kernel
def run(self):
self.result = self.input//2
self.half_input = self.input//2
decimal_fraction = Fraction("1.2")
self.decimal_fraction_n = int(decimal_fraction.numerator)
self.decimal_fraction_d = int(decimal_fraction.denominator)
class _PulseLogger(AutoContext):
@ -144,12 +148,15 @@ class ExecutionCase(unittest.TestCase):
_run_on_host(_Primes, max=100, output_list=l_host)
self.assertEqual(l_device, l_host)
def test_attributes(self):
def test_misc(self):
with comm_serial.Comm() as comm:
coredev = core.Core(comm)
uut = _Attributes(core=coredev)
uut = _Misc(core=coredev)
uut.run()
self.assertEqual(uut.result, 42)
self.assertEqual(uut.half_input, 42)
self.assertEqual(Fraction(uut.decimal_fraction_n,
uut.decimal_fraction_d),
Fraction("1.2"))
def test_pulses(self):
l_device, l_host = [], []