from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from collections.abc import Iterable

from .._utils import flatten, deprecated
from .. import tracer
from .ast import *
from .ast import _StatementList
from .cd import *
from .ir import *
from .rec import *


__all__ = ["ValueVisitor", "ValueTransformer",
           "StatementVisitor", "StatementTransformer",
           "FragmentTransformer",
           "TransformedElaboratable",
           "DomainCollector", "DomainRenamer", "DomainLowerer",
           "SampleDomainInjector", "SampleLowerer",
           "SwitchCleaner", "LHSGroupAnalyzer", "LHSGroupFilter",
           "ResetInserter", "EnableInserter"]


class ValueVisitor(metaclass=ABCMeta):
    """Abstract dispatcher over value nodes.

    :meth:`on_value` inspects the runtime type of a value and forwards it to the matching
    ``on_<Type>`` hook. Subclasses implement every hook.
    """

    @abstractmethod
    def on_Const(self, value):
        pass # :nocov:

    @abstractmethod
    def on_AnyConst(self, value):
        pass # :nocov:

    @abstractmethod
    def on_AnySeq(self, value):
        pass # :nocov:

    @abstractmethod
    def on_Signal(self, value):
        pass # :nocov:

    @abstractmethod
    def on_Record(self, value):
        pass # :nocov:

    @abstractmethod
    def on_ClockSignal(self, value):
        pass # :nocov:

    @abstractmethod
    def on_ResetSignal(self, value):
        pass # :nocov:

    @abstractmethod
    def on_Operator(self, value):
        pass # :nocov:

    @abstractmethod
    def on_Slice(self, value):
        pass # :nocov:

    @abstractmethod
    def on_Part(self, value):
        pass # :nocov:

    @abstractmethod
    def on_Cat(self, value):
        pass # :nocov:

    @abstractmethod
    def on_Repl(self, value):
        pass # :nocov:

    @abstractmethod
    def on_ArrayProxy(self, value):
        pass # :nocov:

    @abstractmethod
    def on_Sample(self, value):
        pass # :nocov:

    @abstractmethod
    def on_Initial(self, value):
        pass # :nocov:

    def on_unknown_value(self, value):
        """Fallback for values of an unrecognized type; raises :exc:`TypeError`."""
        raise TypeError("Cannot transform value {!r}".format(value)) # :nocov:

    def replace_value_src_loc(self, value, new_value):
        """Return whether ``new_value`` should inherit ``value.src_loc``. Always true here."""
        return True

    def on_value(self, value):
        """Dispatch ``value`` to the hook for its type and return the hook's result.

        If the result is a ``Value`` and :meth:`replace_value_src_loc` agrees, the source
        location of ``value`` is copied onto the result.
        """
        if type(value) is Const:
            new_value = self.on_Const(value)
        elif type(value) is AnyConst:
            new_value = self.on_AnyConst(value)
        elif type(value) is AnySeq:
            new_value = self.on_AnySeq(value)
        elif isinstance(value, Signal):
            # Uses `isinstance()` and not `type() is` because nmigen.compat requires it.
            new_value = self.on_Signal(value)
        elif isinstance(value, Record):
            # Uses `isinstance()` and not `type() is` to allow inheriting from Record.
            new_value = self.on_Record(value)
        elif type(value) is ClockSignal:
            new_value = self.on_ClockSignal(value)
        elif type(value) is ResetSignal:
            new_value = self.on_ResetSignal(value)
        elif type(value) is Operator:
            new_value = self.on_Operator(value)
        elif type(value) is Slice:
            new_value = self.on_Slice(value)
        elif type(value) is Part:
            new_value = self.on_Part(value)
        elif type(value) is Cat:
            new_value = self.on_Cat(value)
        elif type(value) is Repl:
            new_value = self.on_Repl(value)
        elif type(value) is ArrayProxy:
            new_value = self.on_ArrayProxy(value)
        elif type(value) is Sample:
            new_value = self.on_Sample(value)
        elif type(value) is Initial:
            new_value = self.on_Initial(value)
        elif isinstance(value, UserValue):
            # Uses `isinstance()` and not `type() is` to allow inheriting.
            new_value = self.on_value(value._lazy_lower())
        else:
            new_value = self.on_unknown_value(value)
        if isinstance(new_value, Value) and self.replace_value_src_loc(value, new_value):
            new_value.src_loc = value.src_loc
        return new_value

    def __call__(self, value):
        return self.on_value(value)


class ValueTransformer(ValueVisitor):
    """Value visitor whose hooks rebuild the value they are given.

    Leaf values are returned unchanged; compound values are reconstructed from their
    recursively transformed operands. Subclasses override individual hooks.
    """

    def on_Const(self, value):
        return value

    def on_AnyConst(self, value):
        return value

    def on_AnySeq(self, value):
        return value

    def on_Signal(self, value):
        return value

    def on_Record(self, value):
        return value

    def on_ClockSignal(self, value):
        return value

    def on_ResetSignal(self, value):
        return value

    def on_Operator(self, value):
        return Operator(value.operator, [self.on_value(o) for o in value.operands])

    def on_Slice(self, value):
        return Slice(self.on_value(value.value), value.start, value.stop)

    def on_Part(self, value):
        return Part(self.on_value(value.value), self.on_value(value.offset),
                    value.width, value.stride)

    def on_Cat(self, value):
        return Cat(self.on_value(o) for o in value.parts)

    def on_Repl(self, value):
        return Repl(self.on_value(value.value), value.count)

    def on_ArrayProxy(self, value):
        return ArrayProxy([self.on_value(elem) for elem in value._iter_as_values()],
                          self.on_value(value.index))

    def on_Sample(self, value):
        return Sample(self.on_value(value.value), value.clocks, value.domain)

    def on_Initial(self, value):
        return value


class StatementVisitor(metaclass=ABCMeta):
    """Abstract dispatcher over statements.

    :meth:`on_statement` forwards a statement (or an iterable of statements) to the matching
    ``on_<Type>`` hook. Subclasses implement every hook.
    """

    @abstractmethod
    def on_Assign(self, stmt):
        pass # :nocov:

    @abstractmethod
    def on_Assert(self, stmt):
        pass # :nocov:

    @abstractmethod
    def on_Assume(self, stmt):
        pass # :nocov:

    @abstractmethod
    def on_Cover(self, stmt):
        pass # :nocov:

    @abstractmethod
    def on_Switch(self, stmt):
        pass # :nocov:

    @abstractmethod
    def on_statements(self, stmts):
        pass # :nocov:

    def on_unknown_statement(self, stmt):
        """Fallback for statements of an unrecognized type; raises :exc:`TypeError`."""
        raise TypeError("Cannot transform statement {!r}".format(stmt)) # :nocov:

    def replace_statement_src_loc(self, stmt, new_stmt):
        """Return whether ``new_stmt`` should inherit ``stmt.src_loc``. Always true here."""
        return True

    def on_statement(self, stmt):
        """Dispatch ``stmt`` to the hook for its type and return the hook's result.

        Source locations (and, for ``Switch``, per-case source locations) are carried over to
        the result when :meth:`replace_statement_src_loc` agrees. A resulting ``Property`` is
        marked as used.
        """
        if type(stmt) is Assign:
            new_stmt = self.on_Assign(stmt)
        elif type(stmt) is Assert:
            new_stmt = self.on_Assert(stmt)
        elif type(stmt) is Assume:
            new_stmt = self.on_Assume(stmt)
        elif type(stmt) is Cover:
            new_stmt = self.on_Cover(stmt)
        elif isinstance(stmt, Switch):
            # Uses `isinstance()` and not `type() is` because nmigen.compat requires it.
            new_stmt = self.on_Switch(stmt)
        elif isinstance(stmt, Iterable):
            new_stmt = self.on_statements(stmt)
        else:
            new_stmt = self.on_unknown_statement(stmt)
        if isinstance(new_stmt, Statement) and self.replace_statement_src_loc(stmt, new_stmt):
            new_stmt.src_loc = stmt.src_loc
            if isinstance(new_stmt, Switch) and isinstance(stmt, Switch):
                new_stmt.case_src_locs = stmt.case_src_locs
        if isinstance(new_stmt, Property):
            new_stmt._MustUse__used = True
        return new_stmt

    def __call__(self, stmt):
        return self.on_statement(stmt)


class StatementTransformer(StatementVisitor):
    """Statement visitor whose hooks rebuild each statement, passing its values to `on_value`.

    The ``on_value`` defined here returns the value unchanged. A subclass that lists a value
    visitor earlier in its bases gets that visitor's ``on_value`` instead.
    """

    def on_value(self, value):
        return value

    def on_Assign(self, stmt):
        return Assign(self.on_value(stmt.lhs), self.on_value(stmt.rhs))

    def on_Assert(self, stmt):
        return Assert(self.on_value(stmt.test), _check=stmt._check, _en=stmt._en)

    def on_Assume(self, stmt):
        return Assume(self.on_value(stmt.test), _check=stmt._check, _en=stmt._en)

    def on_Cover(self, stmt):
        return Cover(self.on_value(stmt.test), _check=stmt._check, _en=stmt._en)

    def on_Switch(self, stmt):
        cases = OrderedDict((k, self.on_statement(s)) for k, s in stmt.cases.items())
        return Switch(self.on_value(stmt.test), cases)

    def on_statements(self, stmts):
        return _StatementList(flatten(self.on_statement(stmt) for stmt in stmts))


class FragmentTransformer:
    """Builds a new fragment from an existing one, piece by piece.

    Each ``map_*`` method copies one aspect of the source fragment into the new fragment.
    Subclasses override individual ``map_*`` methods, or mix in value/statement transformers
    whose ``on_value`` / ``on_statement`` are picked up via ``hasattr``.
    """

    def map_subfragments(self, fragment, new_fragment):
        # Subfragments are transformed recursively with this same transformer.
        for subfragment, name in fragment.subfragments:
            new_fragment.add_subfragment(self(subfragment), name)

    def map_ports(self, fragment, new_fragment):
        for port, dir in fragment.ports.items():
            new_fragment.add_ports(port, dir=dir)

    def map_named_ports(self, fragment, new_fragment):
        if hasattr(self, "on_value"):
            for name, (value, dir) in fragment.named_ports.items():
                new_fragment.named_ports[name] = self.on_value(value), dir
        else:
            new_fragment.named_ports = OrderedDict(fragment.named_ports.items())

    def map_domains(self, fragment, new_fragment):
        for domain in fragment.iter_domains():
            new_fragment.add_domains(fragment.domains[domain])

    def map_statements(self, fragment, new_fragment):
        if hasattr(self, "on_statement"):
            new_fragment.add_statements(map(self.on_statement, fragment.statements))
        else:
            new_fragment.add_statements(fragment.statements)

    def map_drivers(self, fragment, new_fragment):
        for domain, signal in fragment.iter_drivers():
            new_fragment.add_driver(signal, domain)

    def on_fragment(self, fragment):
        """Return a new ``Fragment`` (or ``Instance``) populated through the ``map_*`` methods."""
        if isinstance(fragment, Instance):
            new_fragment = Instance(fragment.type)
            new_fragment.parameters = OrderedDict(fragment.parameters)
            self.map_named_ports(fragment, new_fragment)
        else:
            new_fragment = Fragment()
            new_fragment.flatten = fragment.flatten
        new_fragment.attrs = OrderedDict(fragment.attrs)
        self.map_ports(fragment, new_fragment)
        self.map_subfragments(fragment, new_fragment)
        self.map_domains(fragment, new_fragment)
        self.map_statements(fragment, new_fragment)
        self.map_drivers(fragment, new_fragment)
        return new_fragment

    def __call__(self, value, *, src_loc_at=0):
        """Apply this transformer to a fragment, or schedule it on an elaboratable.

        A ``Fragment`` is transformed immediately. Any other object with an ``elaborate``
        attribute is wrapped in (or, if already wrapped, reused as) a
        :class:`TransformedElaboratable` to which this transformer is appended.

        Raises :exc:`AttributeError` if ``value`` is neither.
        """
        if isinstance(value, Fragment):
            return self.on_fragment(value)
        elif isinstance(value, TransformedElaboratable):
            value._transforms_.append(self)
            return value
        elif hasattr(value, "elaborate"):
            value = TransformedElaboratable(value, src_loc_at=1 + src_loc_at)
            value._transforms_.append(self)
            return value
        else:
            raise AttributeError("Object {!r} cannot be elaborated".format(value))


class TransformedElaboratable(Elaboratable):
    """Wrapper that elaborates an inner object and then applies queued transforms in order.

    Attribute lookups that fail on the wrapper are forwarded to the inner object.
    """

    def __init__(self, elaboratable, *, src_loc_at=0):
        assert hasattr(elaboratable, "elaborate")

        # Fields prefixed and suffixed with underscore to avoid as many conflicts with the inner
        # object as possible, since we're forwarding attribute requests to it.
        self._elaboratable_ = elaboratable
        self._transforms_ = []

    def __getattr__(self, attr):
        # `__getattr__` only runs when normal lookup fails. If `_elaboratable_` itself is
        # missing (e.g. on a bare instance being probed with `hasattr` during copying or
        # unpickling), forwarding would re-enter this method forever; fail cleanly instead.
        if attr == "_elaboratable_":
            raise AttributeError(attr)
        return getattr(self._elaboratable_, attr)

    def elaborate(self, platform):
        fragment = Fragment.get(self._elaboratable_, platform)
        for transform in self._transforms_:
            fragment = transform(fragment)
        return fragment


class DomainCollector(ValueVisitor, StatementVisitor):
    """Walks a fragment tree and records the names of used and defined clock domains.

    After a call, ``used_domains`` holds domain names referenced by clock/reset signals or
    drivers (excluding ``None`` and domains local to an enclosing fragment), and
    ``defined_domains`` holds non-local domain names found in fragments' ``domains``.
    """

    def __init__(self):
        self.used_domains = set()
        self.defined_domains = set()
        self._local_domains = set()

    def _add_used_domain(self, domain_name):
        if domain_name is None:
            return
        if domain_name in self._local_domains:
            return
        self.used_domains.add(domain_name)

    def on_ignore(self, value):
        pass

    on_Const = on_ignore
    on_AnyConst = on_ignore
    on_AnySeq = on_ignore
    on_Signal = on_ignore

    def on_ClockSignal(self, value):
        self._add_used_domain(value.domain)

    def on_ResetSignal(self, value):
        self._add_used_domain(value.domain)

    on_Record = on_ignore

    def on_Operator(self, value):
        for o in value.operands:
            self.on_value(o)

    def on_Slice(self, value):
        self.on_value(value.value)

    def on_Part(self, value):
        self.on_value(value.value)
        self.on_value(value.offset)

    def on_Cat(self, value):
        for o in value.parts:
            self.on_value(o)

    def on_Repl(self, value):
        self.on_value(value.value)

    def on_ArrayProxy(self, value):
        for elem in value._iter_as_values():
            self.on_value(elem)
        self.on_value(value.index)

    def on_Sample(self, value):
        self.on_value(value.value)

    def on_Initial(self, value):
        pass

    def on_Assign(self, stmt):
        self.on_value(stmt.lhs)
        self.on_value(stmt.rhs)

    def on_property(self, stmt):
        self.on_value(stmt.test)

    on_Assert = on_property
    on_Assume = on_property
    on_Cover = on_property

    def on_Switch(self, stmt):
        self.on_value(stmt.test)
        for stmts in stmt.cases.values():
            self.on_statement(stmts)

    def on_statements(self, stmts):
        for stmt in stmts:
            self.on_statement(stmt)

    def on_fragment(self, fragment):
        if isinstance(fragment, Instance):
            for name, (value, dir) in fragment.named_ports.items():
                self.on_value(value)

        # Local domains are scoped to this fragment and its subfragments: work on a copy of the
        # set and restore the previous one on the way out.
        old_local_domains, self._local_domains = self._local_domains, set(self._local_domains)
        for domain_name, domain in fragment.domains.items():
            if domain.local:
                self._local_domains.add(domain_name)
            else:
                self.defined_domains.add(domain_name)

        self.on_statements(fragment.statements)
        for domain_name in fragment.drivers:
            self._add_used_domain(domain_name)
        for subfragment, name in fragment.subfragments:
            self.on_fragment(subfragment)

        self._local_domains = old_local_domains

    def __call__(self, fragment):
        self.on_fragment(fragment)


class DomainRenamer(FragmentTransformer, ValueTransformer, StatementTransformer):
    """Renames clock domains according to ``domain_map``.

    ``domain_map`` is a mapping from old to new domain names; a bare string is shorthand for
    ``{"sync": <string>}``. Raises :exc:`ValueError` if ``"comb"`` appears as a source or a
    destination.
    """

    def __init__(self, domain_map):
        if isinstance(domain_map, str):
            domain_map = {"sync": domain_map}
        for src, dst in domain_map.items():
            if src == "comb":
                raise ValueError("Domain '{}' may not be renamed".format(src))
            if dst == "comb":
                raise ValueError("Domain '{}' may not be renamed to '{}'".format(src, dst))
        self.domain_map = OrderedDict(domain_map)

    def on_ClockSignal(self, value):
        if value.domain in self.domain_map:
            return ClockSignal(self.domain_map[value.domain])
        return value

    def on_ResetSignal(self, value):
        if value.domain in self.domain_map:
            # Carry `allow_reset_less` over to the renamed signal; DomainLowerer consults it
            # when the target domain has no reset, and dropping it would turn a permitted
            # reset-less reference into an error.
            return ResetSignal(self.domain_map[value.domain],
                               allow_reset_less=value.allow_reset_less)
        return value

    def map_domains(self, fragment, new_fragment):
        for domain in fragment.iter_domains():
            cd = fragment.domains[domain]
            if domain in self.domain_map:
                if cd.name == domain:
                    # Rename the actual ClockDomain object.
                    cd.rename(self.domain_map[domain])
                else:
                    assert cd.name == self.domain_map[domain]
            new_fragment.add_domains(cd)

    def map_drivers(self, fragment, new_fragment):
        for domain, signals in fragment.drivers.items():
            if domain in self.domain_map:
                domain = self.domain_map[domain]
            for signal in signals:
                new_fragment.add_driver(self.on_value(signal), domain)


class DomainLowerer(FragmentTransformer, ValueTransformer, StatementTransformer):
    """Replaces ``ClockSignal``/``ResetSignal`` with the signals of the domains they name.

    Also appends, for every driven domain that has a reset, a ``Switch`` on that reset which
    assigns each non-``reset_less`` driven signal its reset value.
    """

    def __init__(self, domains=None):
        self.domains = domains

    def _resolve(self, domain, context):
        if domain not in self.domains:
            raise DomainError("Signal {!r} refers to nonexistent domain '{}'"
                              .format(context, domain))
        return self.domains[domain]

    def map_drivers(self, fragment, new_fragment):
        for domain, signal in fragment.iter_drivers():
            new_fragment.add_driver(self.on_value(signal), domain)

    def replace_value_src_loc(self, value, new_value):
        # A lowered clock/reset is the domain's own signal; keep that signal's source location.
        return not isinstance(value, (ClockSignal, ResetSignal))

    def on_ClockSignal(self, value):
        domain = self._resolve(value.domain, value)
        return domain.clk

    def on_ResetSignal(self, value):
        domain = self._resolve(value.domain, value)
        if domain.rst is None:
            if value.allow_reset_less:
                return Const(0)
            else:
                raise DomainError("Signal {!r} refers to reset of reset-less domain '{}'"
                                  .format(value, value.domain))
        return domain.rst

    def _insert_resets(self, fragment):
        for domain_name, signals in fragment.drivers.items():
            if domain_name is None:
                continue
            domain = fragment.domains[domain_name]
            if domain.rst is None:
                continue
            stmts = [signal.eq(Const(signal.reset, signal.width))
                     for signal in signals if not signal.reset_less]
            fragment.add_statements(Switch(domain.rst, {1: stmts}))

    def on_fragment(self, fragment):
        self.domains = fragment.domains
        new_fragment = super().on_fragment(fragment)
        self._insert_resets(new_fragment)
        return new_fragment


class SampleDomainInjector(ValueTransformer, StatementTransformer):
    """Fills in ``domain`` on every ``Sample`` whose domain is ``None``."""

    def __init__(self, domain):
        self.domain = domain

    def on_Sample(self, value):
        if value.domain is not None:
            return value
        return Sample(value.value, value.clocks, self.domain)

    def __call__(self, stmts):
        return self.on_statement(stmts)


class SampleLowerer(FragmentTransformer, ValueTransformer, StatementTransformer):
    """Replaces ``Sample`` values with chains of signals and ``Initial`` with a shared signal.

    For each ``Sample`` with a non-zero ``clocks``, a new signal is created and an assignment
    from the one-clock-earlier sample is added to the fragment under the sample's domain.
    """

    def __init__(self):
        self.initial = None
        self.sample_cache = None
        self.sample_stmts = None

    def _name_reset(self, value):
        """Return a ``(name fragment, reset value)`` pair for a sampled value."""
        if isinstance(value, Const):
            return "c${}".format(value.value), value.value
        elif isinstance(value, Signal):
            return "s${}".format(value.name), value.reset
        elif isinstance(value, ClockSignal):
            return "clk", 0
        elif isinstance(value, ResetSignal):
            return "rst", 1
        elif isinstance(value, Initial):
            return "init", 0 # Past(Initial()) produces 0, 1, 0, 0, ...
        else:
            raise NotImplementedError # :nocov:

    def on_Sample(self, value):
        if value in self.sample_cache:
            return self.sample_cache[value]

        sampled_value = self.on_value(value.value)
        if value.clocks == 0:
            sample = sampled_value
        else:
            assert value.domain is not None
            sampled_name, sampled_reset = self._name_reset(value.value)
            name = "$sample${}${}${}".format(sampled_name, value.domain, value.clocks)
            sample = Signal.like(value.value, name=name, reset_less=True, reset=sampled_reset)
            sample.attrs["nmigen.sample_reg"] = True

            # Each stage is fed from the sample taken one clock earlier.
            prev_sample = self.on_Sample(Sample(sampled_value, value.clocks - 1, value.domain))
            if value.domain not in self.sample_stmts:
                self.sample_stmts[value.domain] = []
            self.sample_stmts[value.domain].append(sample.eq(prev_sample))

        self.sample_cache[value] = sample
        return sample

    def on_Initial(self, value):
        if self.initial is None:
            self.initial = Signal(name="init")
        return self.initial

    def map_statements(self, fragment, new_fragment):
        # State is reset per fragment, so samples are not shared between fragments.
        self.initial = None
        self.sample_cache = ValueDict()
        self.sample_stmts = OrderedDict()
        new_fragment.add_statements(map(self.on_statement, fragment.statements))
        for domain, stmts in self.sample_stmts.items():
            new_fragment.add_statements(stmts)
            for stmt in stmts:
                new_fragment.add_driver(stmt.lhs, domain)
        if self.initial is not None:
            new_fragment.add_subfragment(Instance("$initstate", o_Y=self.initial))


class SwitchCleaner(StatementVisitor):
    """Removes ``Switch`` statements whose cases are all empty after cleaning.

    Hooks return ``None`` for statements that should be dropped; :meth:`on_statements` filters
    those out.
    """

    def on_ignore(self, stmt):
        return stmt

    on_Assign = on_ignore
    on_Assert = on_ignore
    on_Assume = on_ignore
    on_Cover = on_ignore

    def on_Switch(self, stmt):
        cases = OrderedDict((k, self.on_statement(s)) for k, s in stmt.cases.items())
        # Falls through (returning None) when every case is empty.
        if any(len(s) for s in cases.values()):
            return Switch(stmt.test, cases)

    def on_statements(self, stmts):
        stmts = flatten(self.on_statement(stmt) for stmt in stmts)
        return _StatementList(stmt for stmt in stmts if stmt is not None)


class LHSGroupAnalyzer(StatementVisitor):
    """Partitions signals into groups such that signals assigned together share a group.

    ``signals`` maps each seen signal to a group number and ``unions`` maps a merged group to
    the group it was merged into. Calling the analyzer on statements returns the groups.
    """

    def __init__(self):
        self.signals = SignalDict()
        self.unions = OrderedDict()

    def find(self, signal):
        """Return the current group of ``signal``, registering the signal if it is new."""
        if signal not in self.signals:
            self.signals[signal] = len(self.signals)
        group = self.signals[signal]
        while group in self.unions:
            group = self.unions[group]
        # Remember the resolved group so later lookups skip the chain.
        self.signals[signal] = group
        return group

    def unify(self, root, *leaves):
        """Merge the groups of all ``leaves`` into the group of ``root``."""
        root_group = self.find(root)
        for leaf in leaves:
            leaf_group = self.find(leaf)
            if root_group == leaf_group:
                continue
            self.unions[leaf_group] = root_group

    def groups(self):
        """Return an ordered mapping from group number to the ``SignalSet`` of its members."""
        groups = OrderedDict()
        for signal in self.signals:
            group = self.find(signal)
            if group not in groups:
                groups[group] = SignalSet()
            groups[group].add(signal)
        return groups

    def on_Assign(self, stmt):
        lhs_signals = stmt._lhs_signals()
        if lhs_signals:
            self.unify(*lhs_signals)

    def on_property(self, stmt):
        lhs_signals = stmt._lhs_signals()
        if lhs_signals:
            self.unify(*lhs_signals)

    on_Assert = on_property
    on_Assume = on_property
    on_Cover = on_property

    def on_Switch(self, stmt):
        for case_stmts in stmt.cases.values():
            self.on_statements(case_stmts)

    def on_statements(self, stmts):
        for stmt in stmts:
            self.on_statement(stmt)

    def __call__(self, stmts):
        self.on_statements(stmts)
        return self.groups()


class LHSGroupFilter(SwitchCleaner):
    """Keeps only the statements whose left-hand side belongs to ``signals``."""

    def __init__(self, signals):
        self.signals = signals

    def on_Assign(self, stmt):
        # The invariant provided by LHSGroupAnalyzer is that all signals that ever appear together
        # on LHS are a part of the same group, so it is sufficient to check any of them.
        lhs_signals = stmt.lhs._lhs_signals()
        if lhs_signals:
            any_lhs_signal = next(iter(lhs_signals))
            if any_lhs_signal in self.signals:
                return stmt

    def on_property(self, stmt):
        any_lhs_signal = next(iter(stmt._lhs_signals()))
        if any_lhs_signal in self.signals:
            return stmt

    on_Assert = on_property
    on_Assume = on_property
    on_Cover = on_property


class _ControlInserter(FragmentTransformer):
    """Base for transformers that add a control ``Switch`` per driven, controlled domain.

    ``controls`` maps domain names to control values; a bare ``Value`` is shorthand for
    ``{"sync": <value>}``. Subclasses implement :meth:`_insert_control`.
    """

    def __init__(self, controls):
        self.src_loc = None
        if isinstance(controls, Value):
            controls = {"sync": controls}
        self.controls = OrderedDict(controls)

    def on_fragment(self, fragment):
        new_fragment = super().on_fragment(fragment)
        for domain, signals in fragment.drivers.items():
            if domain is None or domain not in self.controls:
                continue
            self._insert_control(new_fragment, domain, signals)
        return new_fragment

    def _insert_control(self, fragment, domain, signals):
        raise NotImplementedError # :nocov:

    def __call__(self, value, *, src_loc_at=0):
        # Record where the inserter was applied; the generated Switch is attributed to it.
        self.src_loc = tracer.get_src_loc(src_loc_at=src_loc_at)
        return super().__call__(value, src_loc_at=1 + src_loc_at)


class ResetInserter(_ControlInserter):
    """Adds a ``Switch`` that, when the control is 1, assigns driven signals their reset value."""

    def _insert_control(self, fragment, domain, signals):
        stmts = [s.eq(Const(s.reset, s.width)) for s in signals if not s.reset_less]
        fragment.add_statements(Switch(self.controls[domain], {1: stmts}, src_loc=self.src_loc))


class EnableInserter(_ControlInserter):
    """Adds a ``Switch`` that, when the control is 0, assigns each driven signal to itself.

    For ``$memrd`` / ``$memwr`` instances clocked by a controlled domain, the ``EN`` port is
    additionally gated by the control.
    """

    def _insert_control(self, fragment, domain, signals):
        stmts = [s.eq(s) for s in signals]
        fragment.add_statements(Switch(self.controls[domain], {0: stmts}, src_loc=self.src_loc))

    def on_fragment(self, fragment):
        new_fragment = super().on_fragment(fragment)
        if isinstance(new_fragment, Instance) and new_fragment.type in ("$memrd", "$memwr"):
            clk_port, clk_dir = new_fragment.named_ports["CLK"]
            if isinstance(clk_port, ClockSignal) and clk_port.domain in self.controls:
                en_port, en_dir = new_fragment.named_ports["EN"]
                en_port = Mux(self.controls[clk_port.domain], en_port, Const(0, len(en_port)))
                new_fragment.named_ports["EN"] = en_port, en_dir
        return new_fragment