diff --git a/artiq/compiler/builtins.py b/artiq/compiler/builtins.py index 9e44e0da7..bbdc6462b 100644 --- a/artiq/compiler/builtins.py +++ b/artiq/compiler/builtins.py @@ -132,3 +132,7 @@ def is_exn_constructor(typ, name=None): typ.name == name else: return isinstance(typ, types.TExceptionConstructor) + +def is_mutable(typ): + return typ.fold(False, lambda accum, typ: + is_list(typ) or types.is_function(typ)) diff --git a/artiq/compiler/transforms/asttyped_rewriter.py b/artiq/compiler/transforms/asttyped_rewriter.py index 69e39e123..de72fce3e 100644 --- a/artiq/compiler/transforms/asttyped_rewriter.py +++ b/artiq/compiler/transforms/asttyped_rewriter.py @@ -28,32 +28,32 @@ class LocalExtractor(algorithm.Visitor): # parameters can't be declared as global or nonlocal self.params = set() - def visit_in_assign(self, node): + def visit_in_assign(self, node, in_assign): try: - self.in_assign = True + old_in_assign, self.in_assign = self.in_assign, in_assign return self.visit(node) finally: - self.in_assign = False + self.in_assign = old_in_assign def visit_Assign(self, node): self.visit(node.value) for target in node.targets: - self.visit_in_assign(target) + self.visit_in_assign(target, in_assign=True) def visit_For(self, node): self.visit(node.iter) - self.visit_in_assign(node.target) + self.visit_in_assign(node.target, in_assign=True) self.visit(node.body) self.visit(node.orelse) def visit_withitem(self, node): self.visit(node.context_expr) if node.optional_vars is not None: - self.visit_in_assign(node.optional_vars) + self.visit_in_assign(node.optional_vars, in_assign=True) def visit_comprehension(self, node): self.visit(node.iter) - self.visit_in_assign(node.target) + self.visit_in_assign(node.target, in_assign=True) for if_ in node.ifs: self.visit(node.ifs) @@ -99,6 +99,13 @@ class LocalExtractor(algorithm.Visitor): # creates a new binding for x in f's scope self._assignable(node.id) + def visit_Attribute(self, node): + self.visit_in_assign(node.value, in_assign=False) + + def visit_Subscript(self, node): + self.visit_in_assign(node.value, in_assign=False) + self.visit_in_assign(node.slice, in_assign=False) + def _check_not_in(self, name, names, curkind, newkind, loc): if name in names: diag = diagnostic.Diagnostic("error", diff --git a/artiq/compiler/transforms/inferencer.py b/artiq/compiler/transforms/inferencer.py index 5733ac900..6694b6129 100644 --- a/artiq/compiler/transforms/inferencer.py +++ b/artiq/compiler/transforms/inferencer.py @@ -350,7 +350,6 @@ class Inferencer(algorithm.Visitor): return types.is_mono(opreand.type) and \ opreand.type.find().name == typ.find().name other_node = next(filter(wide_enough, operands)) - print(typ, other_node) node.left, *node.comparators = \ [self._coerce_one(typ, operand, other_node) for operand in operands] self._unify(node.type, builtins.TBool(), diff --git a/artiq/compiler/validators/__init__.py b/artiq/compiler/validators/__init__.py index d2ec51350..a90a89c69 100644 --- a/artiq/compiler/validators/__init__.py +++ b/artiq/compiler/validators/__init__.py @@ -1 +1,2 @@ from .monomorphism import MonomorphismValidator +from .escape import EscapeValidator diff --git a/artiq/compiler/validators/escape.py b/artiq/compiler/validators/escape.py index 4f02b1b7c..fd030eeef 100644 --- a/artiq/compiler/validators/escape.py +++ b/artiq/compiler/validators/escape.py @@ -3,8 +3,304 @@ the region of its allocation. """ +import functools from pythonparser import algorithm, diagnostic from .. import asttyped, types, builtins +class Region: + """ + A last-in-first-out allocation region. Tied to lexical scoping + and is internally represented simply by a source range. + + :ivar range: (:class:`pythonparser.source.Range` or None) + """ + + def __init__(self, source_range=None): + self.range = source_range + + def present(self): + return bool(self.range) + + def includes(self, other): + assert self.range + assert self.range.source_buffer == other.range.source_buffer + + return self.range.begin_pos <= other.range.begin_pos and \ + self.range.end_pos >= other.range.end_pos + + def intersects(self, other): + assert self.range.source_buffer == other.range.source_buffer + assert self.range + + return (self.range.begin_pos <= other.range.begin_pos <= self.range.end_pos and \ + other.range.end_pos > self.range.end_pos) or \ + (other.range.begin_pos <= self.range.begin_pos <= other.range.end_pos and \ + self.range.end_pos > other.range.end_pos) + + def contract(self, other): + if not self.range: + self.range = other.range + + def outlives(lhs, rhs): + if lhs is None: # lhs lives forever + return True + elif rhs is None: # rhs lives forever, lhs does not + return False + else: + assert not lhs.intersects(rhs) + return lhs.includes(rhs) + + def __repr__(self): + return "Region({})".format(repr(self.range)) + +class RegionOf(algorithm.Visitor): + """ + Visit an expression and return the list of regions that must + be alive for the expression to execute. + """ + + def __init__(self, env_stack, youngest_region): + self.env_stack, self.youngest_region = env_stack, youngest_region + + # Liveness determined by assignments + def visit_NameT(self, node): + # First, look at stack regions + for region in reversed(self.env_stack[1:]): + if node.id in region: + return region[node.id] + + # Then, look at the global region of this module + if node.id in self.env_stack[0]: + return None + + assert False + + # Value lives as long as the current scope, if it's mutable, + # or else forever + def visit_BinOpT(self, node): + if builtins.is_mutable(node.type): + return self.youngest_region + else: + return None + + # Value lives as long as the object/container, if it's mutable, + # or else forever + def visit_accessor(self, node): + if builtins.is_mutable(node.type): + return self.visit(node.value) + else: + return None + + visit_AttributeT = visit_accessor + visit_SubscriptT = visit_accessor + + # Value lives as long as the shortest living operand + def visit_selecting(self, nodes): + regions = [self.visit(node) for node in nodes] + regions = list(filter(lambda x: x, regions)) + if any(regions): + regions.sort(key=functools.cmp_to_key(Region.outlives), reverse=True) + return regions[0] + else: + return None + + def visit_BoolOpT(self, node): + return self.visit_selecting(node.values) + + def visit_IfExpT(self, node): + return self.visit_selecting([node.body, node.orelse]) + + def visit_TupleT(self, node): + return self.visit_selecting(node.elts) + + # Value lives as long as the current scope + def visit_allocating(self, node): + return self.youngest_region + + visit_DictT = visit_allocating + visit_DictCompT = visit_allocating + visit_GeneratorExpT = visit_allocating + visit_LambdaT = visit_allocating + visit_ListT = visit_allocating + visit_ListCompT = visit_allocating + visit_SetT = visit_allocating + visit_SetCompT = visit_allocating + visit_StrT = visit_allocating + + # Value lives forever + def visit_immutable(self, node): + assert not builtins.is_mutable(node.type) + return None + + visit_CompareT = visit_immutable + visit_EllipsisT = visit_immutable + visit_NameConstantT = visit_immutable + visit_NumT = visit_immutable + visit_UnaryOpT = visit_immutable + visit_CallT = visit_immutable + + # Not implemented + def visit_unimplemented(self, node): + assert False + + visit_StarredT = visit_unimplemented + visit_YieldT = visit_unimplemented + visit_YieldFromT = visit_unimplemented + + +class AssignedNamesOf(algorithm.Visitor): + """ + Visit an expression and return the list of names that appear + on the lhs of assignment, directly or through an accessor. + """ + + def visit_NameT(self, node): + return [node] + + def visit_accessor(self, node): + return self.visit(node.value) + + visit_AttributeT = visit_accessor + visit_SubscriptT = visit_accessor + + def visit_sequence(self, node): + return reduce(list.__add__, map(self.visit, node.elts)) + + visit_TupleT = visit_sequence + visit_ListT = visit_sequence + + def visit_StarredT(self, node): + assert False + + class EscapeValidator(algorithm.Visitor): - pass + def __init__(self, engine): + self.engine = engine + self.youngest_region = None + self.env_stack = [] + self.youngest_env = None + + def _region_of(self, expr): + return RegionOf(self.env_stack, self.youngest_region).visit(expr) + + def _names_of(self, expr): + return AssignedNamesOf().visit(expr) + + def _diagnostics_for(self, region, loc, descr="the value of the expression"): + if region: + return [ + diagnostic.Diagnostic("note", + "{descr} is alive from this point...", {"descr": descr}, + region.range.begin()), + diagnostic.Diagnostic("note", + "... to this point", {}, + region.range.end()) + ] + else: + return [ + diagnostic.Diagnostic("note", + "{descr} is alive forever", {"descr": descr}, + loc) + ] + + def visit_in_region(self, node, region): + try: + old_youngest_region = self.youngest_region + self.youngest_region = region + + old_youngest_env = self.youngest_env + self.youngest_env = {} + + for name in node.typing_env: + if builtins.is_mutable(node.typing_env[name]): + self.youngest_env[name] = Region(None) # not yet known + else: + self.youngest_env[name] = None # lives forever + self.env_stack.append(self.youngest_env) + + self.generic_visit(node) + finally: + self.env_stack.pop() + self.youngest_env = old_youngest_env + self.youngest_region = old_youngest_region + + def visit_ModuleT(self, node): + self.visit_in_region(node, None) + + def visit_FunctionDefT(self, node): + self.youngest_env[node.name] = self.youngest_region + self.visit_in_region(node, Region(node.loc)) + + # Only three ways for a pointer to escape: + # * Assigning or op-assigning it (we ensure an outlives relationship) + # * Returning it (we only allow returning values that live forever) + # * Raising it (we forbid raising mutable data) + # + # Literals doesn't count: a constructed object is always + # outlived by all its constituents. + # Closures don't count: see above. + # Calling functions doesn't count: arguments never outlive + # the function body. + + def visit_assignment(self, target, value, is_aug_assign=False): + target_region = self._region_of(target) + value_region = self._region_of(value) if not is_aug_assign else self.youngest_region + + # If this is a variable, we might need to contract the live range. + if value_region is not None: + for name in self._names_of(target): + region = self._region_of(name) + if region is not None: + region.contract(value_region) + + # The assigned value should outlive the assignee + if not Region.outlives(value_region, target_region): + if is_aug_assign: + target_desc = "the assignment target, allocated here," + else: + target_desc = "the assignment target" + note = diagnostic.Diagnostic("note", + "this expression has type {type}", + {"type": types.TypePrinter().name(value.type)}, + value.loc) + diag = diagnostic.Diagnostic("error", + "the assigned value does not outlive the assignment target", {}, + value.loc, [target.loc], + notes=self._diagnostics_for(target_region, target.loc, + target_desc) + + self._diagnostics_for(value_region, value.loc, + "the assigned value")) + self.engine.process(diag) + + def visit_Assign(self, node): + for target in node.targets: + self.visit_assignment(target, node.value) + + def visit_AugAssign(self, node): + if builtins.is_mutable(node.target.type): + # If the target is mutable, op-assignment will allocate + # in the youngest region. + self.visit_assignment(node.target, node.value, is_aug_assign=True) + + def visit_Return(self, node): + region = self._region_of(node.value) + if region: + note = diagnostic.Diagnostic("note", + "this expression has type {type}", + {"type": types.TypePrinter().name(node.value.type)}, + node.value.loc) + diag = diagnostic.Diagnostic("error", + "cannot return a mutable value that does not live forever", {}, + node.value.loc, notes=self._diagnostics_for(region, node.value.loc) + [note]) + self.engine.process(diag) + + def visit_Raise(self, node): + if builtins.is_mutable(node.exc.type): + note = diagnostic.Diagnostic("note", + "this expression has type {type}", + {"type": types.TypePrinter().name(node.exc.type)}, + node.exc.loc) + diag = diagnostic.Diagnostic("error", + "cannot raise a mutable value", {}, + node.exc.loc, notes=[note]) + self.engine.process(diag)