compiler: monomorphize int64(round(x)) to not lose precision.

This applies to any expression with an indeterminate integer type
cast to int64(), not just round().
This commit is contained in:
whitequark 2016-12-02 15:02:44 +00:00
parent 696db32603
commit 68de724554
8 changed files with 40 additions and 21 deletions

View File

@ -48,6 +48,7 @@ class Module:
self.globals = src.globals self.globals = src.globals
int_monomorphizer = transforms.IntMonomorphizer(engine=self.engine) int_monomorphizer = transforms.IntMonomorphizer(engine=self.engine)
cast_monomorphizer = transforms.CastMonomorphizer(engine=self.engine)
inferencer = transforms.Inferencer(engine=self.engine) inferencer = transforms.Inferencer(engine=self.engine)
monomorphism_validator = validators.MonomorphismValidator(engine=self.engine) monomorphism_validator = validators.MonomorphismValidator(engine=self.engine)
escape_validator = validators.EscapeValidator(engine=self.engine) escape_validator = validators.EscapeValidator(engine=self.engine)
@ -63,6 +64,7 @@ class Module:
interleaver = transforms.Interleaver(engine=self.engine) interleaver = transforms.Interleaver(engine=self.engine)
invariant_detection = analyses.InvariantDetection(engine=self.engine) invariant_detection = analyses.InvariantDetection(engine=self.engine)
cast_monomorphizer.visit(src.typedtree)
int_monomorphizer.visit(src.typedtree) int_monomorphizer.visit(src.typedtree)
inferencer.visit(src.typedtree) inferencer.visit(src.typedtree)
monomorphism_validator.visit(src.typedtree) monomorphism_validator.visit(src.typedtree)

View File

@ -1,7 +1,7 @@
import sys, fileinput, os import sys, fileinput, os
from pythonparser import source, diagnostic, algorithm, parse_buffer from pythonparser import source, diagnostic, algorithm, parse_buffer
from .. import prelude, types from .. import prelude, types
from ..transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer from ..transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer, CastMonomorphizer
from ..transforms import IODelayEstimator from ..transforms import IODelayEstimator
class Printer(algorithm.Visitor): class Printer(algorithm.Visitor):
@ -84,6 +84,7 @@ def main():
typed = ASTTypedRewriter(engine=engine, prelude=prelude.globals()).visit(parsed) typed = ASTTypedRewriter(engine=engine, prelude=prelude.globals()).visit(parsed)
Inferencer(engine=engine).visit(typed) Inferencer(engine=engine).visit(typed)
if monomorphize: if monomorphize:
CastMonomorphizer(engine=engine).visit(typed)
IntMonomorphizer(engine=engine).visit(typed) IntMonomorphizer(engine=engine).visit(typed)
Inferencer(engine=engine).visit(typed) Inferencer(engine=engine).visit(typed)
if iodelay: if iodelay:

View File

@ -1,6 +1,7 @@
from .asttyped_rewriter import ASTTypedRewriter from .asttyped_rewriter import ASTTypedRewriter
from .inferencer import Inferencer from .inferencer import Inferencer
from .int_monomorphizer import IntMonomorphizer from .int_monomorphizer import IntMonomorphizer
from .cast_monomorphizer import CastMonomorphizer
from .iodelay_estimator import IODelayEstimator from .iodelay_estimator import IODelayEstimator
from .artiq_ir_generator import ARTIQIRGenerator from .artiq_ir_generator import ARTIQIRGenerator
from .dead_code_eliminator import DeadCodeEliminator from .dead_code_eliminator import DeadCodeEliminator

View File

@ -0,0 +1,24 @@
"""
:class:`CastMonomorphizer` uses explicit casts to monomorphize
expressions of undetermined integer type to either 32 or 64 bits.
"""
from pythonparser import algorithm, diagnostic
from .. import types, builtins
class CastMonomorphizer(algorithm.Visitor):
def __init__(self, engine):
self.engine = engine
def visit_CallT(self, node):
self.generic_visit(node)
if (types.is_builtin(node.func.type, "int") or
types.is_builtin(node.func.type, "int32") or
types.is_builtin(node.func.type, "int64")):
typ = node.type.find()
if (not types.is_var(typ["width"]) and
builtins.is_int(node.args[0].type) and
types.is_var(node.args[0].type.find()["width"])):
node.args[0].type.unify(typ)

View File

@ -780,7 +780,6 @@ class Inferencer(algorithm.Visitor):
elif types.is_builtin(typ, "round"): elif types.is_builtin(typ, "round"):
valid_forms = lambda: [ valid_forms = lambda: [
valid_form("round(x:float) -> numpy.int?"), valid_form("round(x:float) -> numpy.int?"),
valid_form("round(x:float, width=?) -> numpy.int?")
] ]
self._unify(node.type, builtins.TInt(), self._unify(node.type, builtins.TInt(),
@ -791,19 +790,6 @@ class Inferencer(algorithm.Visitor):
self._unify(arg.type, builtins.TFloat(), self._unify(arg.type, builtins.TFloat(),
arg.loc, None) arg.loc, None)
elif len(node.args) == 1 and len(node.keywords) == 1 and \
builtins.is_numeric(node.args[0].type) and \
node.keywords[0].arg == 'width':
width = node.keywords[0].value
if not (isinstance(width, asttyped.NumT) and isinstance(width.n, int)):
diag = diagnostic.Diagnostic("error",
"the width argument of round() must be an integer literal", {},
node.keywords[0].loc)
self.engine.process(diag)
return
self._unify(node.type, builtins.TInt(types.TValue(width.n)),
node.loc, None)
else: else:
diagnose(valid_forms()) diagnose(valid_forms())
elif types.is_builtin(typ, "min") or types.is_builtin(typ, "max"): elif types.is_builtin(typ, "min") or types.is_builtin(typ, "max"):

View File

@ -30,6 +30,3 @@ len([])
# CHECK-L: round:<function round>(1.0:float):numpy.int? # CHECK-L: round:<function round>(1.0:float):numpy.int?
round(1.0) round(1.0)
# CHECK-L: round:<function round>(1.0:float, width=64:numpy.int?):numpy.int64
round(1.0, width=64)

View File

@ -6,6 +6,3 @@ x = 1
y = int(1) y = int(1)
# CHECK-L: y: numpy.int32 # CHECK-L: y: numpy.int32
z = round(1.0)
# CHECK-L: z: numpy.int32

View File

@ -0,0 +1,11 @@
# RUN: %python -m artiq.compiler.testbench.inferencer +mono %s >%t
# RUN: OutputCheck %s --file-to-check=%t
# CHECK-L: round:<function round>(1.0:float):numpy.int32
round(1.0)
# CHECK-L: round:<function round>(2.0:float):numpy.int32
int32(round(2.0))
# CHECK-L: round:<function round>(3.0:float):numpy.int64
int64(round(3.0))