forked from M-Labs/artiq
transforms/fold_constants: support decimal fractions
This commit is contained in:
parent
57cc6479c4
commit
1f92e19f2b
|
@ -1,5 +1,6 @@
|
|||
import ast
|
||||
import operator
|
||||
from fractions import Fraction
|
||||
|
||||
from artiq.transforms.tools import *
|
||||
from artiq.language.core import int64, round64
|
||||
|
@ -135,14 +136,17 @@ class _ConstantFolder(ast.NodeTransformer):
|
|||
"int": int,
|
||||
"int64": int64,
|
||||
"round": round,
|
||||
"round64": round64
|
||||
"round64": round64,
|
||||
"Fraction": Fraction
|
||||
}
|
||||
if fn in constant_ops:
|
||||
args = []
|
||||
for arg in node.args:
|
||||
try:
|
||||
arg = eval_constant(node.args[0])
|
||||
args.append(eval_constant(arg))
|
||||
except NotConstant:
|
||||
return node
|
||||
result = value_to_ast(constant_ops[fn](arg))
|
||||
result = value_to_ast(constant_ops[fn](*args))
|
||||
return ast.copy_location(result, node)
|
||||
else:
|
||||
return node
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import unittest
|
||||
from operator import itemgetter
|
||||
import os
|
||||
from fractions import Fraction
|
||||
|
||||
from artiq import *
|
||||
from artiq.coredevice import comm_serial, core, runtime_exceptions, rtio
|
||||
|
@ -40,13 +41,16 @@ class _Primes(AutoContext):
|
|||
self.output_list.append(x)
|
||||
|
||||
|
||||
class _Attributes(AutoContext):
|
||||
class _Misc(AutoContext):
|
||||
def build(self):
|
||||
self.input = 84
|
||||
|
||||
@kernel
|
||||
def run(self):
|
||||
self.result = self.input//2
|
||||
self.half_input = self.input//2
|
||||
decimal_fraction = Fraction("1.2")
|
||||
self.decimal_fraction_n = int(decimal_fraction.numerator)
|
||||
self.decimal_fraction_d = int(decimal_fraction.denominator)
|
||||
|
||||
|
||||
class _PulseLogger(AutoContext):
|
||||
|
@ -144,12 +148,15 @@ class ExecutionCase(unittest.TestCase):
|
|||
_run_on_host(_Primes, max=100, output_list=l_host)
|
||||
self.assertEqual(l_device, l_host)
|
||||
|
||||
def test_attributes(self):
|
||||
def test_misc(self):
|
||||
with comm_serial.Comm() as comm:
|
||||
coredev = core.Core(comm)
|
||||
uut = _Attributes(core=coredev)
|
||||
uut = _Misc(core=coredev)
|
||||
uut.run()
|
||||
self.assertEqual(uut.result, 42)
|
||||
self.assertEqual(uut.half_input, 42)
|
||||
self.assertEqual(Fraction(uut.decimal_fraction_n,
|
||||
uut.decimal_fraction_d),
|
||||
Fraction("1.2"))
|
||||
|
||||
def test_pulses(self):
|
||||
l_device, l_host = [], []
|
||||
|
|
Loading…
Reference in New Issue