forked from M-Labs/artiq
compiler: monomorphize int64(round(x)) to not lose precision.
This applies to any expression with an indeterminate integer type cast to int64(), not just round().
This commit is contained in:
parent
696db32603
commit
68de724554
|
@ -48,6 +48,7 @@ class Module:
|
||||||
self.globals = src.globals
|
self.globals = src.globals
|
||||||
|
|
||||||
int_monomorphizer = transforms.IntMonomorphizer(engine=self.engine)
|
int_monomorphizer = transforms.IntMonomorphizer(engine=self.engine)
|
||||||
|
cast_monomorphizer = transforms.CastMonomorphizer(engine=self.engine)
|
||||||
inferencer = transforms.Inferencer(engine=self.engine)
|
inferencer = transforms.Inferencer(engine=self.engine)
|
||||||
monomorphism_validator = validators.MonomorphismValidator(engine=self.engine)
|
monomorphism_validator = validators.MonomorphismValidator(engine=self.engine)
|
||||||
escape_validator = validators.EscapeValidator(engine=self.engine)
|
escape_validator = validators.EscapeValidator(engine=self.engine)
|
||||||
|
@ -63,6 +64,7 @@ class Module:
|
||||||
interleaver = transforms.Interleaver(engine=self.engine)
|
interleaver = transforms.Interleaver(engine=self.engine)
|
||||||
invariant_detection = analyses.InvariantDetection(engine=self.engine)
|
invariant_detection = analyses.InvariantDetection(engine=self.engine)
|
||||||
|
|
||||||
|
cast_monomorphizer.visit(src.typedtree)
|
||||||
int_monomorphizer.visit(src.typedtree)
|
int_monomorphizer.visit(src.typedtree)
|
||||||
inferencer.visit(src.typedtree)
|
inferencer.visit(src.typedtree)
|
||||||
monomorphism_validator.visit(src.typedtree)
|
monomorphism_validator.visit(src.typedtree)
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import sys, fileinput, os
|
import sys, fileinput, os
|
||||||
from pythonparser import source, diagnostic, algorithm, parse_buffer
|
from pythonparser import source, diagnostic, algorithm, parse_buffer
|
||||||
from .. import prelude, types
|
from .. import prelude, types
|
||||||
from ..transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer
|
from ..transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer, CastMonomorphizer
|
||||||
from ..transforms import IODelayEstimator
|
from ..transforms import IODelayEstimator
|
||||||
|
|
||||||
class Printer(algorithm.Visitor):
|
class Printer(algorithm.Visitor):
|
||||||
|
@ -84,6 +84,7 @@ def main():
|
||||||
typed = ASTTypedRewriter(engine=engine, prelude=prelude.globals()).visit(parsed)
|
typed = ASTTypedRewriter(engine=engine, prelude=prelude.globals()).visit(parsed)
|
||||||
Inferencer(engine=engine).visit(typed)
|
Inferencer(engine=engine).visit(typed)
|
||||||
if monomorphize:
|
if monomorphize:
|
||||||
|
CastMonomorphizer(engine=engine).visit(typed)
|
||||||
IntMonomorphizer(engine=engine).visit(typed)
|
IntMonomorphizer(engine=engine).visit(typed)
|
||||||
Inferencer(engine=engine).visit(typed)
|
Inferencer(engine=engine).visit(typed)
|
||||||
if iodelay:
|
if iodelay:
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
from .asttyped_rewriter import ASTTypedRewriter
|
from .asttyped_rewriter import ASTTypedRewriter
|
||||||
from .inferencer import Inferencer
|
from .inferencer import Inferencer
|
||||||
from .int_monomorphizer import IntMonomorphizer
|
from .int_monomorphizer import IntMonomorphizer
|
||||||
|
from .cast_monomorphizer import CastMonomorphizer
|
||||||
from .iodelay_estimator import IODelayEstimator
|
from .iodelay_estimator import IODelayEstimator
|
||||||
from .artiq_ir_generator import ARTIQIRGenerator
|
from .artiq_ir_generator import ARTIQIRGenerator
|
||||||
from .dead_code_eliminator import DeadCodeEliminator
|
from .dead_code_eliminator import DeadCodeEliminator
|
||||||
|
|
|
@ -0,0 +1,24 @@
|
||||||
|
"""
|
||||||
|
:class:`CastMonomorphizer` uses explicit casts to monomorphize
|
||||||
|
expressions of undetermined integer type to either 32 or 64 bits.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pythonparser import algorithm, diagnostic
|
||||||
|
from .. import types, builtins
|
||||||
|
|
||||||
|
class CastMonomorphizer(algorithm.Visitor):
|
||||||
|
def __init__(self, engine):
|
||||||
|
self.engine = engine
|
||||||
|
|
||||||
|
def visit_CallT(self, node):
|
||||||
|
self.generic_visit(node)
|
||||||
|
|
||||||
|
if (types.is_builtin(node.func.type, "int") or
|
||||||
|
types.is_builtin(node.func.type, "int32") or
|
||||||
|
types.is_builtin(node.func.type, "int64")):
|
||||||
|
typ = node.type.find()
|
||||||
|
if (not types.is_var(typ["width"]) and
|
||||||
|
builtins.is_int(node.args[0].type) and
|
||||||
|
types.is_var(node.args[0].type.find()["width"])):
|
||||||
|
node.args[0].type.unify(typ)
|
||||||
|
|
|
@ -780,7 +780,6 @@ class Inferencer(algorithm.Visitor):
|
||||||
elif types.is_builtin(typ, "round"):
|
elif types.is_builtin(typ, "round"):
|
||||||
valid_forms = lambda: [
|
valid_forms = lambda: [
|
||||||
valid_form("round(x:float) -> numpy.int?"),
|
valid_form("round(x:float) -> numpy.int?"),
|
||||||
valid_form("round(x:float, width=?) -> numpy.int?")
|
|
||||||
]
|
]
|
||||||
|
|
||||||
self._unify(node.type, builtins.TInt(),
|
self._unify(node.type, builtins.TInt(),
|
||||||
|
@ -791,19 +790,6 @@ class Inferencer(algorithm.Visitor):
|
||||||
|
|
||||||
self._unify(arg.type, builtins.TFloat(),
|
self._unify(arg.type, builtins.TFloat(),
|
||||||
arg.loc, None)
|
arg.loc, None)
|
||||||
elif len(node.args) == 1 and len(node.keywords) == 1 and \
|
|
||||||
builtins.is_numeric(node.args[0].type) and \
|
|
||||||
node.keywords[0].arg == 'width':
|
|
||||||
width = node.keywords[0].value
|
|
||||||
if not (isinstance(width, asttyped.NumT) and isinstance(width.n, int)):
|
|
||||||
diag = diagnostic.Diagnostic("error",
|
|
||||||
"the width argument of round() must be an integer literal", {},
|
|
||||||
node.keywords[0].loc)
|
|
||||||
self.engine.process(diag)
|
|
||||||
return
|
|
||||||
|
|
||||||
self._unify(node.type, builtins.TInt(types.TValue(width.n)),
|
|
||||||
node.loc, None)
|
|
||||||
else:
|
else:
|
||||||
diagnose(valid_forms())
|
diagnose(valid_forms())
|
||||||
elif types.is_builtin(typ, "min") or types.is_builtin(typ, "max"):
|
elif types.is_builtin(typ, "min") or types.is_builtin(typ, "max"):
|
||||||
|
|
|
@ -30,6 +30,3 @@ len([])
|
||||||
|
|
||||||
# CHECK-L: round:<function round>(1.0:float):numpy.int?
|
# CHECK-L: round:<function round>(1.0:float):numpy.int?
|
||||||
round(1.0)
|
round(1.0)
|
||||||
|
|
||||||
# CHECK-L: round:<function round>(1.0:float, width=64:numpy.int?):numpy.int64
|
|
||||||
round(1.0, width=64)
|
|
||||||
|
|
|
@ -6,6 +6,3 @@ x = 1
|
||||||
|
|
||||||
y = int(1)
|
y = int(1)
|
||||||
# CHECK-L: y: numpy.int32
|
# CHECK-L: y: numpy.int32
|
||||||
|
|
||||||
z = round(1.0)
|
|
||||||
# CHECK-L: z: numpy.int32
|
|
||||||
|
|
|
@ -0,0 +1,11 @@
|
||||||
|
# RUN: %python -m artiq.compiler.testbench.inferencer +mono %s >%t
|
||||||
|
# RUN: OutputCheck %s --file-to-check=%t
|
||||||
|
|
||||||
|
# CHECK-L: round:<function round>(1.0:float):numpy.int32
|
||||||
|
round(1.0)
|
||||||
|
|
||||||
|
# CHECK-L: round:<function round>(2.0:float):numpy.int32
|
||||||
|
int32(round(2.0))
|
||||||
|
|
||||||
|
# CHECK-L: round:<function round>(3.0:float):numpy.int64
|
||||||
|
int64(round(3.0))
|
Loading…
Reference in New Issue