forked from M-Labs/artiq
compiler: monomorphize int64(round(x)) to not lose precision.
This applies to any expression with an indeterminate integer type cast to int64(), not just round().
This commit is contained in:
parent
696db32603
commit
68de724554
|
@ -48,6 +48,7 @@ class Module:
|
|||
self.globals = src.globals
|
||||
|
||||
int_monomorphizer = transforms.IntMonomorphizer(engine=self.engine)
|
||||
cast_monomorphizer = transforms.CastMonomorphizer(engine=self.engine)
|
||||
inferencer = transforms.Inferencer(engine=self.engine)
|
||||
monomorphism_validator = validators.MonomorphismValidator(engine=self.engine)
|
||||
escape_validator = validators.EscapeValidator(engine=self.engine)
|
||||
|
@ -63,6 +64,7 @@ class Module:
|
|||
interleaver = transforms.Interleaver(engine=self.engine)
|
||||
invariant_detection = analyses.InvariantDetection(engine=self.engine)
|
||||
|
||||
cast_monomorphizer.visit(src.typedtree)
|
||||
int_monomorphizer.visit(src.typedtree)
|
||||
inferencer.visit(src.typedtree)
|
||||
monomorphism_validator.visit(src.typedtree)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import sys, fileinput, os
|
||||
from pythonparser import source, diagnostic, algorithm, parse_buffer
|
||||
from .. import prelude, types
|
||||
from ..transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer
|
||||
from ..transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer, CastMonomorphizer
|
||||
from ..transforms import IODelayEstimator
|
||||
|
||||
class Printer(algorithm.Visitor):
|
||||
|
@ -84,6 +84,7 @@ def main():
|
|||
typed = ASTTypedRewriter(engine=engine, prelude=prelude.globals()).visit(parsed)
|
||||
Inferencer(engine=engine).visit(typed)
|
||||
if monomorphize:
|
||||
CastMonomorphizer(engine=engine).visit(typed)
|
||||
IntMonomorphizer(engine=engine).visit(typed)
|
||||
Inferencer(engine=engine).visit(typed)
|
||||
if iodelay:
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from .asttyped_rewriter import ASTTypedRewriter
|
||||
from .inferencer import Inferencer
|
||||
from .int_monomorphizer import IntMonomorphizer
|
||||
from .cast_monomorphizer import CastMonomorphizer
|
||||
from .iodelay_estimator import IODelayEstimator
|
||||
from .artiq_ir_generator import ARTIQIRGenerator
|
||||
from .dead_code_eliminator import DeadCodeEliminator
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
"""
|
||||
:class:`CastMonomorphizer` uses explicit casts to monomorphize
|
||||
expressions of undetermined integer type to either 32 or 64 bits.
|
||||
"""
|
||||
|
||||
from pythonparser import algorithm, diagnostic
|
||||
from .. import types, builtins
|
||||
|
||||
class CastMonomorphizer(algorithm.Visitor):
|
||||
def __init__(self, engine):
|
||||
self.engine = engine
|
||||
|
||||
def visit_CallT(self, node):
|
||||
self.generic_visit(node)
|
||||
|
||||
if (types.is_builtin(node.func.type, "int") or
|
||||
types.is_builtin(node.func.type, "int32") or
|
||||
types.is_builtin(node.func.type, "int64")):
|
||||
typ = node.type.find()
|
||||
if (not types.is_var(typ["width"]) and
|
||||
builtins.is_int(node.args[0].type) and
|
||||
types.is_var(node.args[0].type.find()["width"])):
|
||||
node.args[0].type.unify(typ)
|
||||
|
|
@ -780,7 +780,6 @@ class Inferencer(algorithm.Visitor):
|
|||
elif types.is_builtin(typ, "round"):
|
||||
valid_forms = lambda: [
|
||||
valid_form("round(x:float) -> numpy.int?"),
|
||||
valid_form("round(x:float, width=?) -> numpy.int?")
|
||||
]
|
||||
|
||||
self._unify(node.type, builtins.TInt(),
|
||||
|
@ -791,19 +790,6 @@ class Inferencer(algorithm.Visitor):
|
|||
|
||||
self._unify(arg.type, builtins.TFloat(),
|
||||
arg.loc, None)
|
||||
elif len(node.args) == 1 and len(node.keywords) == 1 and \
|
||||
builtins.is_numeric(node.args[0].type) and \
|
||||
node.keywords[0].arg == 'width':
|
||||
width = node.keywords[0].value
|
||||
if not (isinstance(width, asttyped.NumT) and isinstance(width.n, int)):
|
||||
diag = diagnostic.Diagnostic("error",
|
||||
"the width argument of round() must be an integer literal", {},
|
||||
node.keywords[0].loc)
|
||||
self.engine.process(diag)
|
||||
return
|
||||
|
||||
self._unify(node.type, builtins.TInt(types.TValue(width.n)),
|
||||
node.loc, None)
|
||||
else:
|
||||
diagnose(valid_forms())
|
||||
elif types.is_builtin(typ, "min") or types.is_builtin(typ, "max"):
|
||||
|
|
|
@ -30,6 +30,3 @@ len([])
|
|||
|
||||
# CHECK-L: round:<function round>(1.0:float):numpy.int?
|
||||
round(1.0)
|
||||
|
||||
# CHECK-L: round:<function round>(1.0:float, width=64:numpy.int?):numpy.int64
|
||||
round(1.0, width=64)
|
||||
|
|
|
@ -6,6 +6,3 @@ x = 1
|
|||
|
||||
y = int(1)
|
||||
# CHECK-L: y: numpy.int32
|
||||
|
||||
z = round(1.0)
|
||||
# CHECK-L: z: numpy.int32
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
# RUN: %python -m artiq.compiler.testbench.inferencer +mono %s >%t
|
||||
# RUN: OutputCheck %s --file-to-check=%t
|
||||
|
||||
# CHECK-L: round:<function round>(1.0:float):numpy.int32
|
||||
round(1.0)
|
||||
|
||||
# CHECK-L: round:<function round>(2.0:float):numpy.int32
|
||||
int32(round(2.0))
|
||||
|
||||
# CHECK-L: round:<function round>(3.0:float):numpy.int64
|
||||
int64(round(3.0))
|
Loading…
Reference in New Issue