validators.escape: track region of arguments.

Fixes #232.
This commit is contained in:
whitequark 2016-03-18 03:08:14 +00:00
parent 9492464ed9
commit ac5061c205
3 changed files with 77 additions and 21 deletions

View File

@ -10,6 +10,17 @@ from .. import asttyped, types, builtins
def has_region(typ): def has_region(typ):
return typ.fold(False, lambda accum, typ: accum or builtins.is_allocated(typ)) return typ.fold(False, lambda accum, typ: accum or builtins.is_allocated(typ))
class Global:
def __repr__(self):
return "Global()"
class Argument:
def __init__(self, loc):
self.loc = loc
def __repr__(self):
return "Argument()"
class Region: class Region:
""" """
A last-in-first-out allocation region. Tied to lexical scoping A last-in-first-out allocation region. Tied to lexical scoping
@ -45,9 +56,9 @@ class Region:
self.range = other.range self.range = other.range
def outlives(lhs, rhs): def outlives(lhs, rhs):
if lhs is None: # lhs lives forever if not isinstance(lhs, Region): # lhs lives nonlexically
return True return True
elif rhs is None: # rhs lives forever, lhs does not elif not isinstance(rhs, Region): # rhs lives nonlexically, lhs does not
return False return False
else: else:
assert not lhs.intersects(rhs) assert not lhs.intersects(rhs)
@ -74,7 +85,7 @@ class RegionOf(algorithm.Visitor):
# Then, look at the global region of this module # Then, look at the global region of this module
if node.id in self.env_stack[0]: if node.id in self.env_stack[0]:
return None return Global()
assert False assert False
@ -84,14 +95,14 @@ class RegionOf(algorithm.Visitor):
if has_region(node.type): if has_region(node.type):
return self.youngest_region return self.youngest_region
else: else:
return None return Global()
visit_BinOpT = visit_sometimes_allocating visit_BinOpT = visit_sometimes_allocating
def visit_CallT(self, node): def visit_CallT(self, node):
if types.is_c_function(node.func.type, "cache_get"): if types.is_c_function(node.func.type, "cache_get"):
# The cache is borrow checked dynamically # The cache is borrow checked dynamically
return None return Global()
else: else:
self.visit_sometimes_allocating(node) self.visit_sometimes_allocating(node)
@ -101,7 +112,7 @@ class RegionOf(algorithm.Visitor):
if has_region(node.type): if has_region(node.type):
return self.visit(node.value) return self.visit(node.value)
else: else:
return None return Global()
visit_AttributeT = visit_accessor visit_AttributeT = visit_accessor
visit_SubscriptT = visit_accessor visit_SubscriptT = visit_accessor
@ -114,7 +125,7 @@ class RegionOf(algorithm.Visitor):
regions.sort(key=functools.cmp_to_key(Region.outlives), reverse=True) regions.sort(key=functools.cmp_to_key(Region.outlives), reverse=True)
return regions[0] return regions[0]
else: else:
return None return Global()
def visit_BoolOpT(self, node): def visit_BoolOpT(self, node):
return self.visit_selecting(node.values) return self.visit_selecting(node.values)
@ -141,7 +152,7 @@ class RegionOf(algorithm.Visitor):
# Value lives forever # Value lives forever
def visit_immutable(self, node): def visit_immutable(self, node):
assert not has_region(node.type) assert not has_region(node.type)
return None return Global()
visit_NameConstantT = visit_immutable visit_NameConstantT = visit_immutable
visit_NumT = visit_immutable visit_NumT = visit_immutable
@ -149,9 +160,12 @@ class RegionOf(algorithm.Visitor):
visit_UnaryOpT = visit_immutable visit_UnaryOpT = visit_immutable
visit_CompareT = visit_immutable visit_CompareT = visit_immutable
# Value is mutable, but still lives forever # Value lives forever
def visit_StrT(self, node): def visit_global(self, node):
return None return Global()
visit_StrT = visit_global
visit_QuoteT = visit_global
# Not implemented # Not implemented
def visit_unimplemented(self, node): def visit_unimplemented(self, node):
@ -168,9 +182,12 @@ class AssignedNamesOf(algorithm.Visitor):
on the lhs of assignment, directly or through an accessor. on the lhs of assignment, directly or through an accessor.
""" """
def visit_NameT(self, node): def visit_name(self, node):
return [node] return [node]
visit_NameT = visit_name
visit_QuoteT = visit_name
def visit_accessor(self, node): def visit_accessor(self, node):
return self.visit(node.value) return self.visit(node.value)
@ -190,7 +207,7 @@ class AssignedNamesOf(algorithm.Visitor):
class EscapeValidator(algorithm.Visitor): class EscapeValidator(algorithm.Visitor):
def __init__(self, engine): def __init__(self, engine):
self.engine = engine self.engine = engine
self.youngest_region = None self.youngest_region = Global()
self.env_stack = [] self.env_stack = []
self.youngest_env = None self.youngest_env = None
@ -201,7 +218,7 @@ class EscapeValidator(algorithm.Visitor):
return AssignedNamesOf().visit(expr) return AssignedNamesOf().visit(expr)
def _diagnostics_for(self, region, loc, descr="the value of the expression"): def _diagnostics_for(self, region, loc, descr="the value of the expression"):
if region: if isinstance(region, Region):
return [ return [
diagnostic.Diagnostic("note", diagnostic.Diagnostic("note",
"{descr} is alive from this point...", {"descr": descr}, "{descr} is alive from this point...", {"descr": descr},
@ -210,12 +227,23 @@ class EscapeValidator(algorithm.Visitor):
"... to this point", {}, "... to this point", {},
region.range.end()) region.range.end())
] ]
else: elif isinstance(region, Global):
return [ return [
diagnostic.Diagnostic("note", diagnostic.Diagnostic("note",
"{descr} is alive forever", {"descr": descr}, "{descr} is alive forever", {"descr": descr},
loc) loc)
] ]
elif isinstance(region, Argument):
return [
diagnostic.Diagnostic("note",
"{descr} is still alive after this function returns", {"descr": descr},
loc),
diagnostic.Diagnostic("note",
"{descr} is introduced here as a formal argument", {"descr": descr},
region.loc)
]
else:
assert False
def visit_in_region(self, node, region, typing_env, args=[]): def visit_in_region(self, node, region, typing_env, args=[]):
try: try:
@ -228,11 +256,11 @@ class EscapeValidator(algorithm.Visitor):
for name in typing_env: for name in typing_env:
if has_region(typing_env[name]): if has_region(typing_env[name]):
if name in args: if name in args:
self.youngest_env[name] = self.youngest_region self.youngest_env[name] = args[name]
else: else:
self.youngest_env[name] = Region(None) # not yet known self.youngest_env[name] = Region(None) # not yet known
else: else:
self.youngest_env[name] = None # lives forever self.youngest_env[name] = Global()
self.env_stack.append(self.youngest_env) self.env_stack.append(self.youngest_env)
self.generic_visit(node) self.generic_visit(node)
@ -247,7 +275,7 @@ class EscapeValidator(algorithm.Visitor):
def visit_FunctionDefT(self, node): def visit_FunctionDefT(self, node):
self.youngest_env[node.name] = self.youngest_region self.youngest_env[node.name] = self.youngest_region
self.visit_in_region(node, Region(node.loc), node.typing_env, self.visit_in_region(node, Region(node.loc), node.typing_env,
args=node.signature_type.find().arg_names()) args={ arg.arg: Argument(arg.loc) for arg in node.args.args })
def visit_ClassDefT(self, node): def visit_ClassDefT(self, node):
self.youngest_env[node.name] = self.youngest_region self.youngest_env[node.name] = self.youngest_region
@ -272,10 +300,10 @@ class EscapeValidator(algorithm.Visitor):
value_region = self._region_of(value) if not is_aug_assign else self.youngest_region value_region = self._region_of(value) if not is_aug_assign else self.youngest_region
# If this is a variable, we might need to contract the live range. # If this is a variable, we might need to contract the live range.
if value_region is not None: if isinstance(value_region, Region):
for name in self._names_of(target): for name in self._names_of(target):
region = self._region_of(name) region = self._region_of(name)
if region is not None: if isinstance(region, Region):
region.contract(value_region) region.contract(value_region)
# If we assign to an attribute of a quoted value, there will be no names # If we assign to an attribute of a quoted value, there will be no names
@ -315,7 +343,7 @@ class EscapeValidator(algorithm.Visitor):
def visit_Return(self, node): def visit_Return(self, node):
region = self._region_of(node.value) region = self._region_of(node.value)
if region: if isinstance(region, Region):
note = diagnostic.Diagnostic("note", note = diagnostic.Diagnostic("note",
"this expression has type {type}", "this expression has type {type}",
{"type": types.TypePrinter().name(node.value.type)}, {"type": types.TypePrinter().name(node.value.type)},

View File

@ -0,0 +1,14 @@
{
"comm": {
"type": "local",
"module": "artiq.coredevice.comm_dummy",
"class": "Comm",
"arguments": {}
},
"core": {
"type": "local",
"module": "artiq.coredevice.core",
"class": "Core",
"arguments": {"ref_period": 1e-9}
}
}

View File

@ -0,0 +1,14 @@
# RUN: %python -m artiq.compiler.testbench.embedding +diag %s >%t
# RUN: OutputCheck %s --file-to-check=%t
from artiq.experiment import *
class c:
x = []
cc = c()
@kernel
def entrypoint():
# CHECK-L: ${LINE:+1}: error: the assigned value does not outlive the assignment target
cc.x = [1]