Add MonomorphismChecker.

This commit is contained in:
whitequark 2015-07-02 21:28:26 +03:00
parent 73a8f3c442
commit 02b41ea0f7
9 changed files with 123 additions and 7 deletions

View File

@ -33,7 +33,7 @@ class FunctionDefT(ast.FunctionDef, scoped):
class ModuleT(ast.Module, scoped):
pass
class ExceptHandlerT(ast.ExceptHandler, commontyped):
class ExceptHandlerT(ast.ExceptHandler):
_fields = ("filter", "name", "body") # rename ast.ExceptHandler.type
_types = ("name_type",)

View File

@ -7,28 +7,33 @@ from pythonparser import source, diagnostic, parse_buffer
from . import prelude, types, transforms
class Module:
def __init__(self, source_buffer, engine=diagnostic.Engine(all_errors_are_fatal=True)):
def __init__(self, source_buffer, engine=None):
if engine is None:
engine = diagnostic.Engine(all_errors_are_fatal=True)
asttyped_rewriter = transforms.ASTTypedRewriter(engine=engine)
inferencer = transforms.Inferencer(engine=engine)
int_monomorphizer = transforms.IntMonomorphizer(engine=engine)
monomorphism_checker = transforms.MonomorphismChecker(engine=engine)
parsetree, comments = parse_buffer(source_buffer, engine=engine)
typedtree = asttyped_rewriter.visit(parsetree)
inferencer.visit(typedtree)
int_monomorphizer.visit(typedtree)
inferencer.visit(typedtree)
monomorphism_checker.visit(typedtree)
self.name = os.path.basename(source_buffer.name)
self.globals = asttyped_rewriter.globals
@classmethod
def from_string(klass, source_string, name="input.py", first_line=1):
return klass(source.Buffer(source_string + "\n", name, first_line))
def from_string(klass, source_string, name="input.py", first_line=1, engine=None):
return klass(source.Buffer(source_string + "\n", name, first_line), engine=engine)
@classmethod
def from_filename(klass, filename):
def from_filename(klass, filename, engine=None):
with open(filename) as f:
return klass(source.Buffer(f.read(), filename, 1))
return klass(source.Buffer(f.read(), filename, 1), engine=engine)
def __repr__(self):
printer = types.TypePrinter()

View File

@ -70,6 +70,5 @@ def main():
printer.rewriter.remove(comment.loc)
print(printer.rewrite().source)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,25 @@
import sys, fileinput
from pythonparser import diagnostic
from .. import Module
def main():
if len(sys.argv) > 1 and sys.argv[1] == '+diag':
del sys.argv[1]
def process_diagnostic(diag):
print("\n".join(diag.render(only_line=True)))
if diag.level == 'fatal':
exit()
else:
def process_diagnostic(diag):
print("\n".join(diag.render()))
if diag.level in ('fatal', 'error'):
exit(1)
engine = diagnostic.Engine()
engine.process = process_diagnostic
mod = Module.from_string("".join(fileinput.input()).expandtabs(), engine=engine)
print(repr(mod))
if __name__ == "__main__":
main()

View File

@ -1,3 +1,4 @@
from .asttyped_rewriter import ASTTypedRewriter
from .inferencer import Inferencer
from .int_monomorphizer import IntMonomorphizer
from .monomorphism_checker import MonomorphismChecker

View File

@ -0,0 +1,39 @@
"""
:class:`MonomorphismChecker` verifies that all type variables have been
elided, which is necessary for code generation.
"""
from pythonparser import algorithm, diagnostic
from .. import asttyped, types, builtins
class MonomorphismChecker(algorithm.Visitor):
def __init__(self, engine):
self.engine = engine
def visit_FunctionDefT(self, node):
super().generic_visit(node)
return_type = node.signature_type.find().ret
if types.is_polymorphic(return_type):
note = diagnostic.Diagnostic("note",
"the function has return type {type}",
{"type": types.TypePrinter().name(return_type)},
node.name_loc)
diag = diagnostic.Diagnostic("error",
"the return type of this function cannot be fully inferred", {},
node.name_loc, notes=[note])
self.engine.process(diag)
def generic_visit(self, node):
super().generic_visit(node)
if isinstance(node, asttyped.commontyped):
if types.is_polymorphic(node.type):
note = diagnostic.Diagnostic("note",
"the expression has type {type}",
{"type": types.TypePrinter().name(node.type)},
node.loc)
diag = diagnostic.Diagnostic("error",
"the type of this expression cannot be fully inferred", {},
node.loc, notes=[note])
self.engine.process(diag)

View File

@ -56,6 +56,12 @@ class TVar(Type):
else:
self.find().unify(other)
def fold(self, accum, fn):
if self.parent is self:
return fn(accum, self)
else:
return self.find().fold(accum, fn)
def __repr__(self):
if self.parent is self:
return "<py2llvm.types.TVar %d>" % id(self)
@ -92,6 +98,11 @@ class TMono(Type):
else:
raise UnificationError(self, other)
def fold(self, accum, fn):
for param in self.params:
accum = self.params[param].fold(accum, fn)
return fn(accum, self)
def __repr__(self):
return "py2llvm.types.TMono(%s, %s)" % (repr(self.name), repr(self.params))
@ -131,6 +142,11 @@ class TTuple(Type):
else:
raise UnificationError(self, other)
def fold(self, accum, fn):
for elt in self.elts:
accum = elt.fold(accum, fn)
return fn(accum, self)
def __repr__(self):
return "py2llvm.types.TTuple(%s)" % repr(self.elts)
@ -177,6 +193,14 @@ class TFunction(Type):
else:
raise UnificationError(self, other)
def fold(self, accum, fn):
for arg in self.args:
accum = arg.fold(accum, fn)
for optarg in self.optargs:
accum = self.optargs[optarg].fold(accum, fn)
accum = self.ret.fold(accum, fn)
return fn(accum, self)
def __repr__(self):
return "py2llvm.types.TFunction(%s, %s, %s)" % \
(repr(self.args), repr(self.optargs), repr(self.ret))
@ -208,6 +232,9 @@ class TBuiltin(Type):
if self != other:
raise UnificationError(self, other)
def fold(self, accum, fn):
return fn(accum, self)
def __repr__(self):
return "py2llvm.types.TBuiltin(%s)" % repr(self.name)
@ -258,6 +285,9 @@ class TValue(Type):
elif self != other:
raise UnificationError(self, other)
def fold(self, accum, fn):
return fn(accum, self)
def __repr__(self):
return "py2llvm.types.TValue(%s)" % repr(self.value)
@ -281,6 +311,9 @@ def is_mono(typ, name=None, **params):
return isinstance(typ, TMono) and \
(name is None or (typ.name == name and params_match))
def is_polymorphic(typ):
return typ.fold(False, lambda accum, typ: accum or is_var(typ))
def is_tuple(typ, elts=None):
typ = typ.find()
if elts:

View File

@ -0,0 +1,9 @@
# RUN: %python -m artiq.compiler.testbench.module +diag %s >%t
# RUN: OutputCheck %s --file-to-check=%t
# CHECK-L: ${LINE:+1}: error: the type of this expression cannot be fully inferred
x = int(1)
# CHECK-L: ${LINE:+1}: error: the return type of this function cannot be fully inferred
def fn():
return int(1)

View File

@ -0,0 +1,5 @@
# RUN: %python -m artiq.compiler.testbench.module %s >%t
# RUN: OutputCheck %s --file-to-check=%t
x = 1
# CHECK-L: x: int(width=32)