forked from M-Labs/artiq
transforms/fold_constants: support decimal fractions
This commit is contained in:
parent
57cc6479c4
commit
1f92e19f2b
|
@ -1,5 +1,6 @@
|
||||||
import ast
|
import ast
|
||||||
import operator
|
import operator
|
||||||
|
from fractions import Fraction
|
||||||
|
|
||||||
from artiq.transforms.tools import *
|
from artiq.transforms.tools import *
|
||||||
from artiq.language.core import int64, round64
|
from artiq.language.core import int64, round64
|
||||||
|
@ -135,14 +136,17 @@ class _ConstantFolder(ast.NodeTransformer):
|
||||||
"int": int,
|
"int": int,
|
||||||
"int64": int64,
|
"int64": int64,
|
||||||
"round": round,
|
"round": round,
|
||||||
"round64": round64
|
"round64": round64,
|
||||||
|
"Fraction": Fraction
|
||||||
}
|
}
|
||||||
if fn in constant_ops:
|
if fn in constant_ops:
|
||||||
try:
|
args = []
|
||||||
arg = eval_constant(node.args[0])
|
for arg in node.args:
|
||||||
except NotConstant:
|
try:
|
||||||
return node
|
args.append(eval_constant(arg))
|
||||||
result = value_to_ast(constant_ops[fn](arg))
|
except NotConstant:
|
||||||
|
return node
|
||||||
|
result = value_to_ast(constant_ops[fn](*args))
|
||||||
return ast.copy_location(result, node)
|
return ast.copy_location(result, node)
|
||||||
else:
|
else:
|
||||||
return node
|
return node
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import unittest
|
import unittest
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
import os
|
import os
|
||||||
|
from fractions import Fraction
|
||||||
|
|
||||||
from artiq import *
|
from artiq import *
|
||||||
from artiq.coredevice import comm_serial, core, runtime_exceptions, rtio
|
from artiq.coredevice import comm_serial, core, runtime_exceptions, rtio
|
||||||
|
@ -40,13 +41,16 @@ class _Primes(AutoContext):
|
||||||
self.output_list.append(x)
|
self.output_list.append(x)
|
||||||
|
|
||||||
|
|
||||||
class _Attributes(AutoContext):
|
class _Misc(AutoContext):
|
||||||
def build(self):
|
def build(self):
|
||||||
self.input = 84
|
self.input = 84
|
||||||
|
|
||||||
@kernel
|
@kernel
|
||||||
def run(self):
|
def run(self):
|
||||||
self.result = self.input//2
|
self.half_input = self.input//2
|
||||||
|
decimal_fraction = Fraction("1.2")
|
||||||
|
self.decimal_fraction_n = int(decimal_fraction.numerator)
|
||||||
|
self.decimal_fraction_d = int(decimal_fraction.denominator)
|
||||||
|
|
||||||
|
|
||||||
class _PulseLogger(AutoContext):
|
class _PulseLogger(AutoContext):
|
||||||
|
@ -144,12 +148,15 @@ class ExecutionCase(unittest.TestCase):
|
||||||
_run_on_host(_Primes, max=100, output_list=l_host)
|
_run_on_host(_Primes, max=100, output_list=l_host)
|
||||||
self.assertEqual(l_device, l_host)
|
self.assertEqual(l_device, l_host)
|
||||||
|
|
||||||
def test_attributes(self):
|
def test_misc(self):
|
||||||
with comm_serial.Comm() as comm:
|
with comm_serial.Comm() as comm:
|
||||||
coredev = core.Core(comm)
|
coredev = core.Core(comm)
|
||||||
uut = _Attributes(core=coredev)
|
uut = _Misc(core=coredev)
|
||||||
uut.run()
|
uut.run()
|
||||||
self.assertEqual(uut.result, 42)
|
self.assertEqual(uut.half_input, 42)
|
||||||
|
self.assertEqual(Fraction(uut.decimal_fraction_n,
|
||||||
|
uut.decimal_fraction_d),
|
||||||
|
Fraction("1.2"))
|
||||||
|
|
||||||
def test_pulses(self):
|
def test_pulses(self):
|
||||||
l_device, l_host = [], []
|
l_device, l_host = [], []
|
||||||
|
|
Loading…
Reference in New Issue