transforms/fold_constants: support decimal fractions

This commit is contained in:
Sebastien Bourdeauducq 2014-11-21 15:51:20 -08:00
parent 57cc6479c4
commit 1f92e19f2b
2 changed files with 22 additions and 11 deletions

View File

@ -1,5 +1,6 @@
import ast import ast
import operator import operator
from fractions import Fraction
from artiq.transforms.tools import * from artiq.transforms.tools import *
from artiq.language.core import int64, round64 from artiq.language.core import int64, round64
@ -135,14 +136,17 @@ class _ConstantFolder(ast.NodeTransformer):
"int": int, "int": int,
"int64": int64, "int64": int64,
"round": round, "round": round,
"round64": round64 "round64": round64,
"Fraction": Fraction
} }
if fn in constant_ops: if fn in constant_ops:
try: args = []
arg = eval_constant(node.args[0]) for arg in node.args:
except NotConstant: try:
return node args.append(eval_constant(arg))
result = value_to_ast(constant_ops[fn](arg)) except NotConstant:
return node
result = value_to_ast(constant_ops[fn](*args))
return ast.copy_location(result, node) return ast.copy_location(result, node)
else: else:
return node return node

View File

@ -1,6 +1,7 @@
import unittest import unittest
from operator import itemgetter from operator import itemgetter
import os import os
from fractions import Fraction
from artiq import * from artiq import *
from artiq.coredevice import comm_serial, core, runtime_exceptions, rtio from artiq.coredevice import comm_serial, core, runtime_exceptions, rtio
@ -40,13 +41,16 @@ class _Primes(AutoContext):
self.output_list.append(x) self.output_list.append(x)
class _Attributes(AutoContext): class _Misc(AutoContext):
def build(self): def build(self):
self.input = 84 self.input = 84
@kernel @kernel
def run(self): def run(self):
self.result = self.input//2 self.half_input = self.input//2
decimal_fraction = Fraction("1.2")
self.decimal_fraction_n = int(decimal_fraction.numerator)
self.decimal_fraction_d = int(decimal_fraction.denominator)
class _PulseLogger(AutoContext): class _PulseLogger(AutoContext):
@ -144,12 +148,15 @@ class ExecutionCase(unittest.TestCase):
_run_on_host(_Primes, max=100, output_list=l_host) _run_on_host(_Primes, max=100, output_list=l_host)
self.assertEqual(l_device, l_host) self.assertEqual(l_device, l_host)
def test_attributes(self): def test_misc(self):
with comm_serial.Comm() as comm: with comm_serial.Comm() as comm:
coredev = core.Core(comm) coredev = core.Core(comm)
uut = _Attributes(core=coredev) uut = _Misc(core=coredev)
uut.run() uut.run()
self.assertEqual(uut.result, 42) self.assertEqual(uut.half_input, 42)
self.assertEqual(Fraction(uut.decimal_fraction_n,
uut.decimal_fraction_d),
Fraction("1.2"))
def test_pulses(self): def test_pulses(self):
l_device, l_host = [], [] l_device, l_host = [], []