compiler.embedding: use typedtree hash to iterate inference to fixpoint.

This commit is contained in:
whitequark 2015-08-27 17:04:28 -05:00
parent a3284f8978
commit 9791cbba4d
1 changed files with 28 additions and 5 deletions

View File

@ -8,7 +8,7 @@ annotated as ``@kernel`` when they are referenced.
import os, re, linecache, inspect, textwrap import os, re, linecache, inspect, textwrap
from collections import OrderedDict, defaultdict from collections import OrderedDict, defaultdict
from pythonparser import ast, source, diagnostic, parse_buffer from pythonparser import ast, algorithm, source, diagnostic, parse_buffer
from . import types, builtins, asttyped, prelude from . import types, builtins, asttyped, prelude
from .transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer from .transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer
@ -101,6 +101,7 @@ class ASTSynthesizer:
constructor_type = types.TConstructor(instance_type) constructor_type = types.TConstructor(instance_type)
constructor_type.attributes['__objectid__'] = builtins.TInt(types.TValue(32)) constructor_type.attributes['__objectid__'] = builtins.TInt(types.TValue(32))
instance_type.constructor = constructor_type
self.type_map[typ] = instance_type, constructor_type self.type_map[typ] = instance_type, constructor_type
@ -253,6 +254,24 @@ class StitchingInferencer(Inferencer):
super().visit_AttributeT(node) super().visit_AttributeT(node)
class TypedtreeHasher(algorithm.Visitor):
def generic_visit(self, node):
def freeze(obj):
if isinstance(obj, ast.AST):
return self.visit(obj)
elif isinstance(obj, types.Type):
return hash(obj.find())
elif isinstance(obj, list):
return tuple(obj)
else:
assert obj is None or isinstance(obj, (bool, int, float, str))
return obj
fields = node._fields
if hasattr(node, '_types'):
fields = fields + node._types
return hash(tuple(freeze(getattr(node, field_name)) for field_name in fields))
class Stitcher: class Stitcher:
def __init__(self, engine=None): def __init__(self, engine=None):
if engine is None: if engine is None:
@ -275,12 +294,17 @@ class Stitcher:
inferencer = StitchingInferencer(engine=self.engine, inferencer = StitchingInferencer(engine=self.engine,
value_map=self.value_map, value_map=self.value_map,
quote=self._quote) quote=self._quote)
hasher = TypedtreeHasher()
# Iterate inference to fixed point. # Iterate inference to fixed point.
self.inference_finished = False old_typedtree_hash = None
while not self.inference_finished: while True:
self.inference_finished = True
inferencer.visit(self.typedtree) inferencer.visit(self.typedtree)
typedtree_hash = hasher.visit(self.typedtree)
if old_typedtree_hash == typedtree_hash:
break
old_typedtree_hash = typedtree_hash
# After we have found all functions, synthesize a module to hold them. # After we have found all functions, synthesize a module to hold them.
source_buffer = source.Buffer("", "<synthesized>") source_buffer = source.Buffer("", "<synthesized>")
@ -488,7 +512,6 @@ class Stitcher:
# the final call. # the final call.
function_node = self._quote_embedded_function(function) function_node = self._quote_embedded_function(function)
self.typedtree.insert(0, function_node) self.typedtree.insert(0, function_node)
self.inference_finished = False
return function_node.name return function_node.name
elif function.artiq_embedded.syscall is not None: elif function.artiq_embedded.syscall is not None:
# Insert a storage-less global whose type instructs the compiler # Insert a storage-less global whose type instructs the compiler