compiler: monomorphize int64(round(x)) to not lose precision.

This applies to any expression with an indeterminate integer type
cast to int64(), not just round().
This commit is contained in:
whitequark 2016-12-02 15:02:44 +00:00
parent 696db32603
commit 68de724554
8 changed files with 40 additions and 21 deletions

View File

@ -48,6 +48,7 @@ class Module:
self.globals = src.globals
int_monomorphizer = transforms.IntMonomorphizer(engine=self.engine)
cast_monomorphizer = transforms.CastMonomorphizer(engine=self.engine)
inferencer = transforms.Inferencer(engine=self.engine)
monomorphism_validator = validators.MonomorphismValidator(engine=self.engine)
escape_validator = validators.EscapeValidator(engine=self.engine)
@ -63,6 +64,7 @@ class Module:
interleaver = transforms.Interleaver(engine=self.engine)
invariant_detection = analyses.InvariantDetection(engine=self.engine)
cast_monomorphizer.visit(src.typedtree)
int_monomorphizer.visit(src.typedtree)
inferencer.visit(src.typedtree)
monomorphism_validator.visit(src.typedtree)

View File

@ -1,7 +1,7 @@
import sys, fileinput, os
from pythonparser import source, diagnostic, algorithm, parse_buffer
from .. import prelude, types
from ..transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer
from ..transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer, CastMonomorphizer
from ..transforms import IODelayEstimator
class Printer(algorithm.Visitor):
@ -84,6 +84,7 @@ def main():
typed = ASTTypedRewriter(engine=engine, prelude=prelude.globals()).visit(parsed)
Inferencer(engine=engine).visit(typed)
if monomorphize:
CastMonomorphizer(engine=engine).visit(typed)
IntMonomorphizer(engine=engine).visit(typed)
Inferencer(engine=engine).visit(typed)
if iodelay:

View File

@ -1,6 +1,7 @@
from .asttyped_rewriter import ASTTypedRewriter
from .inferencer import Inferencer
from .int_monomorphizer import IntMonomorphizer
from .cast_monomorphizer import CastMonomorphizer
from .iodelay_estimator import IODelayEstimator
from .artiq_ir_generator import ARTIQIRGenerator
from .dead_code_eliminator import DeadCodeEliminator

View File

@ -0,0 +1,24 @@
"""
:class:`CastMonomorphizer` uses explicit casts to monomorphize
expressions of undetermined integer type to either 32 or 64 bits.
"""
from pythonparser import algorithm, diagnostic
from .. import types, builtins
class CastMonomorphizer(algorithm.Visitor):
def __init__(self, engine):
self.engine = engine
def visit_CallT(self, node):
self.generic_visit(node)
if (types.is_builtin(node.func.type, "int") or
types.is_builtin(node.func.type, "int32") or
types.is_builtin(node.func.type, "int64")):
typ = node.type.find()
if (not types.is_var(typ["width"]) and
builtins.is_int(node.args[0].type) and
types.is_var(node.args[0].type.find()["width"])):
node.args[0].type.unify(typ)

View File

@ -780,7 +780,6 @@ class Inferencer(algorithm.Visitor):
elif types.is_builtin(typ, "round"):
valid_forms = lambda: [
valid_form("round(x:float) -> numpy.int?"),
valid_form("round(x:float, width=?) -> numpy.int?")
]
self._unify(node.type, builtins.TInt(),
@ -791,19 +790,6 @@ class Inferencer(algorithm.Visitor):
self._unify(arg.type, builtins.TFloat(),
arg.loc, None)
elif len(node.args) == 1 and len(node.keywords) == 1 and \
builtins.is_numeric(node.args[0].type) and \
node.keywords[0].arg == 'width':
width = node.keywords[0].value
if not (isinstance(width, asttyped.NumT) and isinstance(width.n, int)):
diag = diagnostic.Diagnostic("error",
"the width argument of round() must be an integer literal", {},
node.keywords[0].loc)
self.engine.process(diag)
return
self._unify(node.type, builtins.TInt(types.TValue(width.n)),
node.loc, None)
else:
diagnose(valid_forms())
elif types.is_builtin(typ, "min") or types.is_builtin(typ, "max"):

View File

@ -30,6 +30,3 @@ len([])
# CHECK-L: round:<function round>(1.0:float):numpy.int?
round(1.0)
# CHECK-L: round:<function round>(1.0:float, width=64:numpy.int?):numpy.int64
round(1.0, width=64)

View File

@ -6,6 +6,3 @@ x = 1
y = int(1)
# CHECK-L: y: numpy.int32
z = round(1.0)
# CHECK-L: z: numpy.int32

View File

@ -0,0 +1,11 @@
# RUN: %python -m artiq.compiler.testbench.inferencer +mono %s >%t
# RUN: OutputCheck %s --file-to-check=%t
# CHECK-L: round:<function round>(1.0:float):numpy.int32
round(1.0)
# CHECK-L: round:<function round>(2.0:float):numpy.int32
int32(round(2.0))
# CHECK-L: round:<function round>(3.0:float):numpy.int64
int64(round(3.0))