compiler.module: split off inferencing from Module.__init__.

This commit is contained in:
whitequark 2015-08-06 08:24:41 +03:00
parent ca52b2fdd0
commit 7562d39750
10 changed files with 58 additions and 41 deletions

View File

@ -1 +1 @@
from .module import Module from .module import Module, Source

View File

@ -1,36 +1,62 @@
""" """
The :class:`Module` class encapsulates a single Python The :class:`Module` class encapsulates a single Python module,
which corresponds to a single ARTIQ translation unit (one LLVM
bitcode file and one object file, unless LTO is used).
A :class:`Module` can be created from a typed AST.
The :class:`Source` class parses a single source file or
string and infers types for it using a trivial :module:`prelude`.
""" """
import os import os
from pythonparser import source, diagnostic, parse_buffer from pythonparser import source, diagnostic, parse_buffer
from . import prelude, types, transforms, validators from . import prelude, types, transforms, validators
class Module: class Source:
def __init__(self, source_buffer, engine=None): def __init__(self, source_buffer, engine=None):
if engine is None: if engine is None:
engine = diagnostic.Engine(all_errors_are_fatal=True) self.engine = diagnostic.Engine(all_errors_are_fatal=True)
else:
self.engine = engine
self.name, _ = os.path.splitext(os.path.basename(source_buffer.name)) self.name, _ = os.path.splitext(os.path.basename(source_buffer.name))
asttyped_rewriter = transforms.ASTTypedRewriter(engine=engine) asttyped_rewriter = transforms.ASTTypedRewriter(engine=engine,
globals=prelude.globals())
inferencer = transforms.Inferencer(engine=engine) inferencer = transforms.Inferencer(engine=engine)
int_monomorphizer = transforms.IntMonomorphizer(engine=engine)
monomorphism_validator = validators.MonomorphismValidator(engine=engine)
escape_validator = validators.EscapeValidator(engine=engine)
artiq_ir_generator = transforms.ARTIQIRGenerator(engine=engine, module_name=self.name)
dead_code_eliminator = transforms.DeadCodeEliminator(engine=engine)
local_access_validator = validators.LocalAccessValidator(engine=engine)
self.parsetree, self.comments = parse_buffer(source_buffer, engine=engine) self.parsetree, self.comments = parse_buffer(source_buffer, engine=engine)
self.typedtree = asttyped_rewriter.visit(self.parsetree) self.typedtree = asttyped_rewriter.visit(self.parsetree)
self.globals = asttyped_rewriter.globals self.globals = asttyped_rewriter.globals
inferencer.visit(self.typedtree) inferencer.visit(self.typedtree)
int_monomorphizer.visit(self.typedtree)
inferencer.visit(self.typedtree) @classmethod
monomorphism_validator.visit(self.typedtree) def from_string(cls, source_string, name="input.py", first_line=1, engine=None):
escape_validator.visit(self.typedtree) return cls(source.Buffer(source_string + "\n", name, first_line), engine=engine)
self.artiq_ir = artiq_ir_generator.visit(self.typedtree)
@classmethod
def from_filename(cls, filename, engine=None):
with open(filename) as f:
return cls(source.Buffer(f.read(), filename, 1), engine=engine)
class Module:
def __init__(self, src):
int_monomorphizer = transforms.IntMonomorphizer(engine=src.engine)
inferencer = transforms.Inferencer(engine=src.engine)
monomorphism_validator = validators.MonomorphismValidator(engine=src.engine)
escape_validator = validators.EscapeValidator(engine=src.engine)
artiq_ir_generator = transforms.ARTIQIRGenerator(engine=src.engine,
module_name=src.name)
dead_code_eliminator = transforms.DeadCodeEliminator(engine=src.engine)
local_access_validator = validators.LocalAccessValidator(engine=src.engine)
self.name = src.name
self.globals = src.globals
int_monomorphizer.visit(src.typedtree)
inferencer.visit(src.typedtree)
monomorphism_validator.visit(src.typedtree)
escape_validator.visit(src.typedtree)
self.artiq_ir = artiq_ir_generator.visit(src.typedtree)
dead_code_eliminator.process(self.artiq_ir) dead_code_eliminator.process(self.artiq_ir)
# local_access_validator.process(self.artiq_ir) # local_access_validator.process(self.artiq_ir)
@ -43,15 +69,6 @@ class Module:
"""Return the name of the function that is the entry point of this module.""" """Return the name of the function that is the entry point of this module."""
return self.name + ".__modinit__" return self.name + ".__modinit__"
@classmethod
def from_string(cls, source_string, name="input.py", first_line=1, engine=None):
return cls(source.Buffer(source_string + "\n", name, first_line), engine=engine)
@classmethod
def from_filename(cls, filename, engine=None):
with open(filename) as f:
return cls(source.Buffer(f.read(), filename, 1), engine=engine)
def __repr__(self): def __repr__(self):
printer = types.TypePrinter() printer = types.TypePrinter()
globals = ["%s: %s" % (var, printer.name(self.globals[var])) for var in self.globals] globals = ["%s: %s" % (var, printer.name(self.globals[var])) for var in self.globals]

View File

@ -66,7 +66,7 @@ def main():
buf = source.Buffer("".join(fileinput.input()).expandtabs(), buf = source.Buffer("".join(fileinput.input()).expandtabs(),
os.path.basename(fileinput.filename())) os.path.basename(fileinput.filename()))
parsed, comments = parse_buffer(buf, engine=engine) parsed, comments = parse_buffer(buf, engine=engine)
typed = ASTTypedRewriter(engine=engine).visit(parsed) typed = ASTTypedRewriter(engine=engine, globals=prelude.globals()).visit(parsed)
Inferencer(engine=engine).visit(typed) Inferencer(engine=engine).visit(typed)
if monomorphize: if monomorphize:
IntMonomorphizer(engine=engine).visit(typed) IntMonomorphizer(engine=engine).visit(typed)

View File

@ -1,6 +1,6 @@
import sys, fileinput import sys, fileinput
from pythonparser import diagnostic from pythonparser import diagnostic
from .. import Module from .. import Module, Source
def main(): def main():
def process_diagnostic(diag): def process_diagnostic(diag):
@ -11,7 +11,7 @@ def main():
engine = diagnostic.Engine() engine = diagnostic.Engine()
engine.process = process_diagnostic engine.process = process_diagnostic
mod = Module.from_string("".join(fileinput.input()).expandtabs(), engine=engine) mod = Module(Source.from_string("".join(fileinput.input()).expandtabs(), engine=engine))
for fn in mod.artiq_ir: for fn in mod.artiq_ir:
print(fn) print(fn)

View File

@ -1,7 +1,7 @@
import os, sys, fileinput, ctypes import os, sys, fileinput, ctypes
from pythonparser import diagnostic from pythonparser import diagnostic
from llvmlite_artiq import binding as llvm from llvmlite_artiq import binding as llvm
from .. import Module from .. import Module, Source
from ..targets import NativeTarget from ..targets import NativeTarget
def main(): def main():
@ -19,7 +19,7 @@ def main():
source = "".join(fileinput.input()) source = "".join(fileinput.input())
source = source.replace("#ARTIQ#", "") source = source.replace("#ARTIQ#", "")
mod = Module.from_string(source.expandtabs(), engine=engine) mod = Module(Source.from_string(source.expandtabs(), engine=engine))
target = NativeTarget() target = NativeTarget()
llmod = mod.build_llvm_ir(target) llmod = mod.build_llvm_ir(target)

View File

@ -1,7 +1,7 @@
import sys, fileinput import sys, fileinput
from pythonparser import diagnostic from pythonparser import diagnostic
from llvmlite_artiq import ir as ll from llvmlite_artiq import ir as ll
from .. import Module from .. import Module, Source
from ..targets import NativeTarget from ..targets import NativeTarget
def main(): def main():
@ -13,7 +13,7 @@ def main():
engine = diagnostic.Engine() engine = diagnostic.Engine()
engine.process = process_diagnostic engine.process = process_diagnostic
mod = Module.from_string("".join(fileinput.input()).expandtabs(), engine=engine) mod = Module(Source.from_string("".join(fileinput.input()).expandtabs(), engine=engine))
target = NativeTarget() target = NativeTarget()
llmod = mod.build_llvm_ir(target=target) llmod = mod.build_llvm_ir(target=target)

View File

@ -1,6 +1,6 @@
import sys, os, time, cProfile as profile, pstats import sys, os, time, cProfile as profile, pstats
from pythonparser import diagnostic from pythonparser import diagnostic
from .. import Module from .. import Module, Source
from ..targets import OR1KTarget from ..targets import OR1KTarget
def main(): def main():
@ -17,7 +17,7 @@ def main():
engine.process = process_diagnostic engine.process = process_diagnostic
# Make sure everything's valid # Make sure everything's valid
modules = [Module.from_filename(filename, engine=engine) modules = [Module(Source.from_filename(filename, engine=engine))
for filename in sys.argv[1:]] for filename in sys.argv[1:]]
def benchmark(f, name): def benchmark(f, name):

View File

@ -1,6 +1,6 @@
import sys, os import sys, os
from pythonparser import diagnostic from pythonparser import diagnostic
from .. import Module from .. import Module, Source
from ..targets import OR1KTarget from ..targets import OR1KTarget
def main(): def main():
@ -18,7 +18,7 @@ def main():
modules = [] modules = []
for filename in sys.argv[1:]: for filename in sys.argv[1:]:
modules.append(Module.from_filename(filename, engine=engine)) modules.append(Module(Source.from_filename(filename, engine=engine)))
llobj = OR1KTarget().compile_and_link(modules) llobj = OR1KTarget().compile_and_link(modules)

View File

@ -1,6 +1,6 @@
import sys, fileinput import sys, fileinput
from pythonparser import diagnostic from pythonparser import diagnostic
from .. import Module from .. import Module, Source
def main(): def main():
if len(sys.argv) > 1 and sys.argv[1] == "+diag": if len(sys.argv) > 1 and sys.argv[1] == "+diag":
@ -21,7 +21,7 @@ def main():
engine.process = process_diagnostic engine.process = process_diagnostic
try: try:
mod = Module.from_string("".join(fileinput.input()).expandtabs(), engine=engine) mod = Module(Source.from_string("".join(fileinput.input()).expandtabs(), engine=engine))
print(repr(mod)) print(repr(mod))
except: except:
if not diag: raise if not diag: raise

View File

@ -4,7 +4,7 @@ to a typedtree (:mod:`..asttyped`).
""" """
from pythonparser import algorithm, diagnostic from pythonparser import algorithm, diagnostic
from .. import asttyped, types, builtins, prelude from .. import asttyped, types, builtins
# This visitor will be called for every node with a scope, # This visitor will be called for every node with a scope,
# i.e.: class, function, comprehension, lambda # i.e.: class, function, comprehension, lambda
@ -185,10 +185,10 @@ class ASTTypedRewriter(algorithm.Transformer):
via :class:`LocalExtractor`. via :class:`LocalExtractor`.
""" """
def __init__(self, engine): def __init__(self, engine, globals):
self.engine = engine self.engine = engine
self.globals = None self.globals = None
self.env_stack = [prelude.globals()] self.env_stack = [globals]
def _find_name(self, name, loc): def _find_name(self, name, loc):
for typing_env in reversed(self.env_stack): for typing_env in reversed(self.env_stack):