From 5a9fa937ee035500e2ff3fd69ac520e18b4c04f4 Mon Sep 17 00:00:00 2001 From: morgan Date: Thu, 25 Jan 2024 16:33:13 +0800 Subject: [PATCH] sim: refactor into OOP & add FW collector --- src/wrpll_simulation/config.py | 81 ++++++ src/wrpll_simulation/sim.py | 387 --------------------------- src/wrpll_simulation/timesim.py | 127 +++++++++ src/wrpll_simulation/timesim_node.py | 220 +++++++++++++++ src/wrpll_simulation/wave_gen.py | 27 -- src/wrpll_simulation/wrpll.py | 117 -------- 6 files changed, 428 insertions(+), 531 deletions(-) create mode 100644 src/wrpll_simulation/config.py delete mode 100644 src/wrpll_simulation/sim.py create mode 100644 src/wrpll_simulation/timesim.py create mode 100644 src/wrpll_simulation/timesim_node.py delete mode 100644 src/wrpll_simulation/wave_gen.py delete mode 100644 src/wrpll_simulation/wrpll.py diff --git a/src/wrpll_simulation/config.py b/src/wrpll_simulation/config.py new file mode 100644 index 0000000..ae54754 --- /dev/null +++ b/src/wrpll_simulation/config.py @@ -0,0 +1,81 @@ +from numba.core.types import * +from numba.experimental import jitclass + + +@jitclass +class PI_Config(object): + KP: int64 + KI: int64 + + def __init__(self, KP, KI): + self.KP = KP + self.KI = KI + + +@jitclass +class Timesim_Config(object): + timestep_size: float32 + sim_length: int64 + helper_PI: PI_Config + main_PI: PI_Config + has_jitter: bool + + step_input_time: float64 + step_frequency: float64 + step_phase: float64 + + # preset + beating_period: int64 + blind_period: int16 + adpll_limit: int32 + + gtx_init_phase: float64 + gtx_init_freq: float64 + gtx_jitter: float64 + + helper_init_phase: float64 + main_init_phase: float64 + helper_init_freq: float64 + main_init_freq: float64 + dcxo_jitter: float64 + + irq_delay: float64 + i2c_comm_delay: float64 + dcxo_settling_delay: float64 + + # jitclass does not support **kwargs + def __init__( + self, + timestep_size: float32, + sim_length: int64, + helper_PI: PI_Config, + main_PI: PI_Config, + has_jitter: bool, + ): + # jitter << timestep_size + # otherwise simulation will add negative phase to Phase_Accumlator + self.timestep_size = timestep_size + self.sim_length = sim_length + self.helper_PI = helper_PI + self.main_PI = main_PI + self.has_jitter = has_jitter + + # preset + self.beating_period = 32768 + self.blind_period = 200 + self.adpll_limit = 8161512 + + self.gtx_init_phase = 0.0 + self.gtx_init_freq = 125_000_000 + self.gtx_jitter = 200e-17 + + self.helper_init_phase = 0.0 + self.main_init_phase = 0.0 + self.helper_init_freq = 125_000_000 * (1 - (1 / self.beating_period)) + self.main_init_freq = 125_000_000 + self.dcxo_jitter = 95e-17 + + # hardware delay + self.irq_delay = 2e-6 # ~2us for the interrupt handling + self.i2c_comm_delay = 85.6e-6 + self.dcxo_settling_delay = 100e-6 diff --git a/src/wrpll_simulation/sim.py b/src/wrpll_simulation/sim.py deleted file mode 100644 index 270baca..0000000 --- a/src/wrpll_simulation/sim.py +++ /dev/null @@ -1,387 +0,0 @@ -import numpy as np -from numba import njit -from wrpll_simulation.wave_gen import square - - -@njit -def simulation_jit( - time, - gtx_freq, - gtx_jitter, - helper_pll, - main_pll, - h_KP, - h_KI, - h_KD, - m_KP, - m_KI, - m_KD, - dcxo_freq, - h_jitter, - m_jitter, - base_adpll, - N, - adpll_write_period, - i2c_comm_delay, - dcxo_settling_delay, - blind_period, - start_up_delay, - helper_init_freq=0, -): - arr_len = len(time) - - main = np.zeros(arr_len, dtype=np.int8) - helper = np.zeros(arr_len, dtype=np.int8) - gtx = np.zeros(arr_len, dtype=np.int8) - - gtx_tag = 0 - gtx_ready = 0 - gtx_FF = 0 - gtx_beating = np.zeros(arr_len, dtype=np.int8) - - main_tag = 0 - main_ready = 0 - main_FF = 0 - main_beating = np.zeros(arr_len, dtype=np.int8) - - phase_err_arr = np.zeros(arr_len, dtype=np.int16) - period_err_arr = np.zeros(arr_len, dtype=np.int16) - helper_adpll_arr = np.zeros(arr_len, dtype=np.int32) - main_adpll_arr = np.zeros(arr_len, dtype=np.int32) - - helperfreq = np.zeros(arr_len, dtype=np.int32) - mainfreq = np.zeros(arr_len, dtype=np.int32) - - phase_collector_r = 0 - colr_gtx_tag = colr_main_tag = 0 - phase_collector_state = 0 - - period_collector_r = 0 - colr_last_gtx_tag = 0 - beating_period = 0 - period_collector_state = 0 - - # initial condition - timestep = time[1] - time[0] - gtx_phase = 0 - h_phase = np.random.uniform(0, 360) - m_phase = np.random.uniform(0, 360) - - if main_pll and not helper_pll: - helper_freq = helper_init_freq - else: - helper_freq = ( - dcxo_freq * (1 + base_adpll * 0.0001164 / 1_000_000) * ((N - 1) / N) - ) - main_freq = dcxo_freq * (1 + base_adpll * 0.0001164 / 1_000_000) - - last_gtx_tag = last_gtx_FF = last_gtx_beat = 0 - last_main_tag = last_main_FF = last_main_beat = 0 - - last_helper = 0 - - counter = 0 - gtx_blind_counter = 0 - gtx_blinded = False - - main_blind_counter = 0 - main_blinded = False - - wait_gtx = True - wait_main = True - - # firmware values - FW_gtx_tag = 0 - FW_main_tag = 0 - - period_err = last_period_err = 0 - h_prop = h_integrator = h_derivative = 0 - h_adpll = base_adpll - h_i2c_active_index = 0 - period_colr_arm = h_i2c_active = False - - phase_err = last_phase_err = 0 - m_prop = m_integrator = m_derivative = 0 - m_adpll = base_adpll - m_i2c_active_index = 0 - phase_colr_arm = m_i2c_active = False - - adpll_max = 8161512 - - def clip(n, minn, maxn): - return max(min(maxn, n), minn) - - for i, t in enumerate(time): - h_phase += 360 * helper_freq * (timestep + h_jitter[i]) - helper[i] = square(h_phase) - - m_phase += 360 * main_freq * (timestep + m_jitter[i]) - main[i] = square(m_phase) - - gtx_phase += 360 * gtx_freq * (timestep + gtx_jitter[i]) - gtx[i] = square(gtx_phase) - - if not last_helper and helper[i]: - gtx_FF, gtx_beating[i] = DDMTD(gtx[i], last_gtx_FF) - main_FF, main_beating[i] = DDMTD(main[i], last_main_FF) - - gtx_tag, gtx_ready, gtx_blind_counter, gtx_blinded = Deglitcher( - gtx_beating[i], - gtx_tag, - gtx_ready, - gtx_blind_counter, - gtx_blinded, - blind_period, - last_gtx_beat, - last_gtx_tag, - counter, - ) - - main_tag, main_ready, main_blind_counter, main_blinded = Deglitcher( - main_beating[i], - main_tag, - main_ready, - main_blind_counter, - main_blinded, - blind_period, - last_main_beat, - last_main_tag, - counter, - ) - - ( - phase_collector_r, - wait_gtx, - wait_main, - colr_gtx_tag, - colr_main_tag, - phase_collector_state, - ) = phase_collector_FSM( - gtx_ready, - main_ready, - gtx_tag, - main_tag, - wait_gtx, - wait_main, - colr_gtx_tag, - colr_main_tag, - phase_collector_state, - ) - - if phase_collector_r: - FW_gtx_tag = colr_gtx_tag - FW_main_tag = colr_main_tag - - ( - period_collector_r, - colr_last_gtx_tag, - beating_period, - period_collector_state, - ) = period_collector_FSM( - gtx_ready, gtx_tag, colr_last_gtx_tag, period_collector_state - ) - - counter += 1 - - last_gtx_beat = gtx_beating[i] - last_gtx_FF = gtx_FF - last_gtx_tag = gtx_tag - last_main_beat = main_beating[i] - last_main_FF = main_FF - last_main_tag = main_tag - - else: - gtx_beating[i] = last_gtx_beat - gtx_FF = last_gtx_FF - gtx_tag = last_gtx_tag - - main_beating[i] = last_main_beat - main_FF = last_main_FF - main_tag = last_main_tag - - if i > start_up_delay: - if i % adpll_write_period == 0: - period_colr_arm = phase_colr_arm = True - - # Firmware filters - - if period_colr_arm and period_collector_r: - period_colr_arm = False - period_err = N - beating_period - - if helper_pll: - h_prop = period_err * h_KP - h_integrator += period_err * h_KI - h_derivative = (period_err - last_period_err) * h_KD - - h_adpll = clip( - int(base_adpll + h_prop + h_integrator + h_derivative), - -adpll_max, - adpll_max, - ) - last_period_err = period_err - h_i2c_active_index = i - h_i2c_active = True - - if phase_colr_arm and phase_collector_r: - phase_colr_arm = False - tag_diff = (FW_main_tag - FW_gtx_tag) % N - if tag_diff > N / 2: - phase_err = tag_diff - N - else: - phase_err = tag_diff - - if main_pll: - m_prop = phase_err * m_KP - m_integrator += phase_err * m_KI - m_derivative = (phase_err - last_phase_err) * m_KD - - m_adpll = clip( - int(base_adpll + m_prop + m_integrator + m_derivative), - -adpll_max, - adpll_max, - ) - last_phase_err = phase_err - - m_i2c_active_index = i - m_i2c_active = True - - # i2c communication delay - - if ( - h_i2c_active - and i >= i2c_comm_delay + dcxo_settling_delay + h_i2c_active_index - ): - helper_freq = ( - dcxo_freq * (1 + h_adpll * 0.0001164 / 1_000_000) * ((N - 1) / N) - ) - h_i2c_active = False - - if ( - m_i2c_active - and i >= i2c_comm_delay + dcxo_settling_delay + m_i2c_active_index - ): - main_freq = dcxo_freq * (1 + m_adpll * 0.0001164 / 1_000_000) - m_i2c_active = False - - last_helper = helper[i] - - # Data - period_err_arr[i] = period_err - phase_err_arr[i] = phase_err - helper_adpll_arr[i] = h_adpll - main_adpll_arr[i] = m_adpll - helperfreq[i] = helper_freq - mainfreq[i] = main_freq - - return ( - period_err_arr, - phase_err_arr, - helper_adpll_arr, - main_adpll_arr, - gtx_beating, - main_beating, - gtx, - helper, - main, - helperfreq, - mainfreq, - ) - - -@njit -def DDMTD(sig_in, last_FF): - return sig_in, last_FF - - -@njit -def Deglitcher( - beating, - t_out, - t_ready, - blind_counter, - blinded, - blind_period, - last_beat, - last_tag, - counter, -): - if blind_counter == 0 and beating and not last_beat: # rising - t_out = counter - t_ready = 1 - blinded = True - else: - t_out = last_tag - t_ready = 0 - - if beating: - blind_counter = blind_period - 1 - - if blind_counter != 0: - blind_counter -= 1 - - return t_out, t_ready, blind_counter, blinded - - -@njit -def phase_collector_FSM( - g_tag_r, - m_tag_r, - gtx_tag, - main_tag, - wait_gtx, - wait_main, - colr_gtx_tag, - colr_main_tag, - FSM_state, -): - collector_r = 0 - - match FSM_state: - case 0: # IDEL - if g_tag_r and m_tag_r: - colr_gtx_tag = gtx_tag - colr_main_tag = main_tag - FSM_state = 3 # OUTPUT - - elif g_tag_r: - colr_gtx_tag = gtx_tag - wait_main = True - FSM_state = 2 # WAITMAIN - - elif m_tag_r: - colr_main_tag = main_tag - wait_gtx = True - FSM_state = 1 # WAITGTX - case 1: # WAITGTX - if g_tag_r: - colr_gtx_tag = gtx_tag - FSM_state = 3 # OUTPUT - case 2: # WAITMAIN - if m_tag_r: - colr_main_tag = main_tag - FSM_state = 3 # OUTPUT - case 3: # OUTPUT - wait_gtx = wait_main = False - collector_r = 1 - FSM_state = 0 - - return collector_r, wait_gtx, wait_main, colr_gtx_tag, colr_main_tag, FSM_state - - -@njit -def period_collector_FSM(g_tag_r, gtx_tag, colr_last_gtx_tag, FSM_state): - collector_r = 0 - beating_period = 0 - match FSM_state: - case 0: # IDEL - if g_tag_r: - colr_last_gtx_tag = gtx_tag - FSM_state = 1 - case 1: # OUTPUT - if g_tag_r: - beating_period = gtx_tag - colr_last_gtx_tag - collector_r = 1 - FSM_state = 0 # IDEL - - return collector_r, colr_last_gtx_tag, beating_period, FSM_state diff --git a/src/wrpll_simulation/timesim.py b/src/wrpll_simulation/timesim.py new file mode 100644 index 0000000..13c4251 --- /dev/null +++ b/src/wrpll_simulation/timesim.py @@ -0,0 +1,127 @@ +import typing as tp +import numpy as np +from numba.core.types import * +from numba.experimental import jitclass + +from wrpll_simulation.config import Timesim_Config +from wrpll_simulation.timesim_node import * + + +@jitclass +class WRPLL_Timesim(object): + cfg: Timesim_Config + + time: float32[:] + + freq_diff: float32[:] + phase_diff: float32[:] + helper_error: int16[:] + main_error: int16[:] + + def __init__(self, cfg: Timesim_Config, rng: tp.Generator): + # subclass/inheritance is not supported by numba jitclass + # https://github.com/numba/numba/issues/1694 + self.cfg = cfg + + sim_length = cfg.sim_length + stop_time = cfg.timestep_size * sim_length + self.time = np.linspace(0, stop_time, sim_length).astype(np.float32) + + self.freq_diff = np.zeros(sim_length, dtype=np.float32) + self.phase_diff = np.zeros(sim_length, dtype=np.float32) + self.helper_error = np.zeros(sim_length, dtype=np.int16) + self.main_error = np.zeros(sim_length, dtype=np.int16) + + # __post_init__ is not supported by numba jitclass + # https://github.com/numba/numba/issues/4037 + self.simulate(rng) + + def simulate(self, rng: tp.Generator): + cfg = self.cfg + + timestep_size = cfg.timestep_size + irq_delay = self.seconds_to_step(cfg.irq_delay) + i2c_comm_delay = self.seconds_to_step(cfg.i2c_comm_delay) + + # simulation node + gtx = Phase_Accumlator(cfg.gtx_init_freq, cfg.gtx_init_phase) + helper = Phase_Accumlator(cfg.helper_init_freq, cfg.helper_init_phase) + main = Phase_Accumlator(cfg.main_init_freq, cfg.main_init_phase) + + ddmtd_gtx = DDMTD(cfg.blind_period) + ddmtd_main = DDMTD(cfg.blind_period) + gtx_tag_irq = EventManager_IRQ() + main_tag_irq = EventManager_IRQ() + + tag_collector = Tag_Collector(cfg.beating_period) + helper_PLL = PI_loop(cfg.helper_PI, cfg.helper_init_freq, 0, cfg.adpll_limit) + main_PLL = PI_loop(cfg.main_PI, cfg.main_init_freq, 0, cfg.adpll_limit) + + counter = 0 + + print("Running...") + for i in range(cfg.sim_length): + if cfg.has_jitter: + gtx.update(timestep_size + rng.normal(0, cfg.gtx_jitter)) + helper.update(timestep_size + rng.normal(0, cfg.dcxo_jitter)) + main.update(timestep_size + rng.normal(0, cfg.dcxo_jitter)) + else: + gtx.update(timestep_size) + helper.update(timestep_size) + main.update(timestep_size) + + # GATEWARE + if helper.is_rising(): + ddmtd_gtx.sync_update(gtx.o, counter) + ddmtd_main.sync_update(main.o, counter) + + # for clock domain crossing + gtx_tag_irq.multireg(ddmtd_gtx.tag_ready, ddmtd_gtx.tag) + main_tag_irq.multireg(ddmtd_main.tag_ready, ddmtd_main.tag) + + counter += 1 + + if main.is_rising(): + # Generate interrupt request + gtx_tag_irq.sync_update(i) + main_tag_irq.sync_update(i) + + # FIRMWARE + if gtx_tag_irq.is_due(i, irq_delay): + tag_collector.collect_gtx_tag(gtx_tag_irq.tag_csr) + helper_PLL.update(i, tag_collector.get_period_error()) + + if tag_collector.is_phase_error_ready(): + tag_collector.set_phase_error_ready(False) + main_PLL.update(i, tag_collector.get_phase_error()) + + if main_tag_irq.is_due(i, irq_delay): + tag_collector.collect_main_tag(main_tag_irq.tag_csr) + + if tag_collector.is_phase_error_ready(): + tag_collector.set_phase_error_ready(False) + main_PLL.update(i, tag_collector.get_phase_error()) + + if helper_PLL.i2c_is_due(i, i2c_comm_delay): + helper.set_freq(helper_PLL.get_new_freq()) + + if main_PLL.i2c_is_due(i, i2c_comm_delay): + main.set_freq(main_PLL.get_new_freq()) + + # Data Logging + self.freq_diff[i] = np.float32(main.freq - gtx.freq) + self.phase_diff[i] = np.float32((main.phase - gtx.phase) % 360) + if self.phase_diff[i] > 180: + self.phase_diff[i] -= 360 + + if helper_PLL.i2c_is_due(i, i2c_comm_delay): + self.helper_error[i] = tag_collector.get_period_error() + elif i > 0: + self.helper_error[i] = self.helper_error[i - 1] + if main_PLL.i2c_is_due(i, i2c_comm_delay): + self.main_error[i] = tag_collector.get_phase_error() + elif i > 0: + self.main_error[i] = self.main_error[i - 1] + + def seconds_to_step(self, seconds: float64): + return int(seconds / self.cfg.timestep_size) diff --git a/src/wrpll_simulation/timesim_node.py b/src/wrpll_simulation/timesim_node.py new file mode 100644 index 0000000..7daec75 --- /dev/null +++ b/src/wrpll_simulation/timesim_node.py @@ -0,0 +1,220 @@ +from numba.core.types import * +from numba.experimental import jitclass + +from wrpll_simulation.config import PI_Config + + +@jitclass +class Phase_Accumlator(object): + last_o: int8 + o: int8 + freq: float64 + phase: float64 + + def __init__(self, freq: float64, phase: float64): + self.last_o = 0 + self.o = 0 + self.freq = freq + self.phase = phase + + def update(self, time_increment: float64): + self.last_o = self.o + + self.phase = (self.phase + 360 * self.freq * time_increment) % 360 + + # square wave function + if self.phase < 180: + self.o = 0 + else: + self.o = 1 + + def set_freq(self, freq: float64): + self.freq = freq + + def set_phase(self, phase: float64): + self.phase = phase + + def is_rising(self) -> bool: + return not self.last_o and self.o + + +# GATEWARE +@jitclass +class DDMTD(object): + FF: int8 + beating: int8 + last_beating: int8 + + tag: int32 + tag_ready: int8 + blind_counter: int16 + blind_period: int16 + + def __init__(self, blind_period: int16): + # back to back Flip Flop + self.FF = 0 + self.beating = 0 + self.last_beating = 0 + + # deglticher + self.tag = 0 + self.tag_ready = 0 + self.blind_counter = 0 + self.blind_period = blind_period + + def sync_update(self, D_in: int8, counter: int32): + self.last_beating = self.beating + + # FF shifting + self.beating = self.FF + self.FF = D_in + + self.deglitcher_first_edge(counter) + + def deglitcher_first_edge(self, counter: int32): + if self.blind_counter == 0 and self.beating_is_rising(): + self.tag = counter + self.tag_ready = 1 + else: + self.tag_ready = 0 + + if self.beating: + self.blind_counter = self.blind_period - 1 + + if self.blind_counter != 0: + self.blind_counter -= 1 + + def beating_is_rising(self) -> bool: + return not self.last_beating and self.beating + + +@jitclass +class EventManager_IRQ(object): + trigger: int8 + tag_csr: int32 + + trigger_index: int64 + + def __init__(self): + self.trigger = 0 + self.tag_csr = 0 + + # for simulating delay + self.trigger_index = 0 + + def sync_update(self, index): + if self.trigger: + self.trigger = 0 + self.trigger_index = index + + def multireg(self, trigger: int8, tag: int32): + if trigger: + self.trigger = 1 + self.tag_csr = tag + + def is_due(self, index, delay): + return index - self.trigger_index == delay + + +# FIRMWARE +@jitclass +class Tag_Collector(object): + setpt_beating_period: int64 + last_gtx: int32 + + gtx_tag_ready: bool + gtx_tag: int32 + main_tag_ready: bool + main_tag: int32 + + def __init__(self, setpt_beating_period: int64): + self.setpt_beating_period = setpt_beating_period + self.last_gtx = 0 + + # for main PLL + self.gtx_tag_ready = False + self.gtx_tag = 0 + self.main_tag_ready = False + self.main_tag = 0 + + def collect_gtx_tag(self, tag: int32): + self.last_gtx = self.gtx + self.gtx_tag = tag + self.gtx_tag_ready = True + + def collect_main_tag(self, tag: int32): + self.main_tag = tag + self.main_tag_ready = True + + def get_period_error(self) -> int32: + return self.set_phase_error_ready - (self.gtx - self.last_gtx) + + def get_phase_error(self) -> int32: + # tag_diff = main_tag(n) - gtx_tag(n) + tag_diff = (self.main_tag - self.gtx_tag) % self.setpt_beating_period + + # mapping tags from [0, 2π] -> [-π, π] + if tag_diff > self.setpt_beating_period / 2: + return tag_diff - self.setpt_beating_period + + return tag_diff + + def set_phase_error_ready(self, ready: bool): + self.main_tag_ready = ready + self.gtx_tag_ready = ready + + def is_phase_error_ready(self) -> bool: + return self.main_tag_ready and self.gtx_tag_ready + + +@jitclass +class PI_loop(object): + KP: int64 + KI: int64 + integrator: int64 + + center_freq: float64 + adpll: int32 + base_adpll: int32 + adpll_limit: int32 + + i2c_transfer_index: int64 + + def __init__( + self, + PI_Conifg: PI_Config, + center_freq: float64, + base_adpll: int32, + adpll_limit: int32, + ): + # PI controller + self.KP = PI_Conifg.KP + self.KI = PI_Conifg.KI + self.integrator = 0 + + # ADPLL calcuation + self.center_freq = center_freq + self.adpll = base_adpll + self.base_adpll = base_adpll + self.adpll_limit = adpll_limit + + # for simulating delay + self.i2c_transfer_index = 0 + + def update(self, index, tag_error: int32): + self.i2c_transfer_index = index + self.adpll = self.get_adpll(tag_error) + + def get_adpll(self, tag_error: int32) -> int32: + prop = tag_error * self.KP + self.integrator += tag_error * self.KI + return self.cramp_adpll(self.base_adpll + int(prop + self.integrator)) + + def get_new_freq(self): + return self.center_freq * (1 + self.adpll * 0.0001164 / 1_000_000) + + def cramp_adpll(self, adpll: int32) -> int32: + return max(min(self.adpll_limit, adpll), -self.adpll_limit) + + def i2c_is_due(self, index, delay): + return index - self.i2c_transfer_index == delay diff --git a/src/wrpll_simulation/wave_gen.py b/src/wrpll_simulation/wave_gen.py deleted file mode 100644 index aa30f18..0000000 --- a/src/wrpll_simulation/wave_gen.py +++ /dev/null @@ -1,27 +0,0 @@ -import numpy as np -from numba import njit - - -def gussian_jitter(RMS_jitter, size, seed=None): - return np.random.default_rng(seed).normal(0, RMS_jitter / 2, size) - - -@njit(fastmath=True) -def square_with_jitter(time, freq, jitter): - n = len(time) - wave = np.empty(n) - timestep = time[1] - time[0] - - phase = 0.0 - for i in range(n): - phase += 360 * freq * (timestep + jitter[i]) - wave[i] = square(phase) - return wave - - -@njit -def square(x): - if np.mod(x, 360) < 180: - return 0 - else: - return 1 diff --git a/src/wrpll_simulation/wrpll.py b/src/wrpll_simulation/wrpll.py deleted file mode 100644 index 9a2cfe0..0000000 --- a/src/wrpll_simulation/wrpll.py +++ /dev/null @@ -1,117 +0,0 @@ -import numpy as np -from wrpll_simulation.sim import simulation_jit -from wrpll_simulation.wave_gen import gussian_jitter - - -class WRPLL_simulator: - def __init__( - self, - timestep, - total_steps, - sim_mode, - helper_filter, - main_filter, - gtx_freq, - adpll_write_period, - start_up_delay, - i2c_comm_delay=85.6e-6, - dcxo_settling_delay=100e-6, - gtx_jitter=200e-15, - dcxo_freq=125_000_000, - dcxo_jitter=95e-15, - freq_acquisition_error=100, - N=4069, - blind_period=128, - helper_init_freq=None, - seed=None, - ): - self.time = np.linspace(0, timestep * total_steps, total_steps) - self.sim_mode = sim_mode - self.h_KP = helper_filter["KP"] - self.h_KI = helper_filter["KI"] - self.h_KD = helper_filter["KD"] - self.m_KP = main_filter["KP"] - self.m_KI = main_filter["KI"] - self.m_KD = main_filter["KD"] - - # init condition - self.gtx_freq = gtx_freq - self.gtx_jitter = gussian_jitter(gtx_jitter, len(self.time), seed) - - self.dcxo_freq = dcxo_freq - self.h_jitter = gussian_jitter(dcxo_jitter, len(self.time), seed) - self.m_jitter = gussian_jitter(dcxo_jitter, len(self.time), seed) - self.N = N - self.helper_init_freq = helper_init_freq - - # freq_acquisition() error - freq_diff = ( - gtx_freq - - dcxo_freq - + np.random.default_rng(seed).uniform( - -freq_acquisition_error, freq_acquisition_error - ) - ) - self.base_adpll = int(freq_diff * (1 / dcxo_freq) * (1e6 / 0.0001164)) - - # sim config - self.i2c_comm_delay = int(i2c_comm_delay / timestep) - self.dcxo_settling_delay = int(dcxo_settling_delay / timestep) - self.blind_period = blind_period - self.adpll_write_period = int(adpll_write_period / timestep) - self.start_up_delay = int(start_up_delay / timestep) - - if type(self.sim_mode) is not str: - raise ValueError(f"pll_type {type(self.sim_mode)} is not a string") - - self.helper_pll = self.main_pll = False - if self.sim_mode.lower() == "both": - self.helper_pll = self.main_pll = True - elif self.sim_mode.lower() == "helper_pll": - self.helper_pll = True - elif self.sim_mode.lower() == "main_pll": - if self.helper_init_freq is None: - raise ValueError("main pll mode need to set a helper frequency") - self.main_pll = True - else: - raise ValueError("sim_mode is not helper_pll nor main_pll") - - def run(self): - print("running simulation...") - ( - self.period_err, - self.phase_err, - self.helper_adpll, - self.main_adpll, - self.gtx_beating, - self.main_beating, - self.gtx, - self.helper, - self.main, - self.helperfreq, - self.mainfreq, - ) = simulation_jit( - self.time, - self.gtx_freq, - self.gtx_jitter, - self.helper_pll, - self.main_pll, - self.h_KP, - self.h_KI, - self.h_KD, - self.m_KP, - self.m_KI, - self.m_KD, - self.dcxo_freq, - self.h_jitter, - self.m_jitter, - self.base_adpll, - self.N, - self.adpll_write_period, - self.i2c_comm_delay, - self.dcxo_settling_delay, - self.blind_period, - self.start_up_delay, - self.helper_init_freq, - ) - print("Done!")