diff --git a/README.md b/README.md index 3160365..b0d139b 100644 --- a/README.md +++ b/README.md @@ -49,20 +49,3 @@ source .venv/bin/activate ## Limitation As the simulation is not cycle nor delay accurate, there will be more glitches than the hardware implementation - -### Helper PLL glitches (remedies are added to follow hardware behavior) - -- Cycle slipping issue will appear as $|\Delta{period}| \sim N$ - - During hardware testing, slipping issue is not common - - It's recommended to turn cycle_slip_comp ON to reduce slipping and have a more accurate simulation - -![cycle_slip](img/cycle_slipping.png) - - -- Deglitcher fail issue will appear as $|\Delta{period}| \sim N/2$ - - There are no such issue for hardware - - It's recommended to set blind_period higher than the hardware setting (around 300 is sufficient) - -![deglitch_fail](img/deglitch_fail.png) - - diff --git a/img/cycle_slipping.png b/img/cycle_slipping.png deleted file mode 100644 index 426c8d2..0000000 Binary files a/img/cycle_slipping.png and /dev/null differ diff --git a/img/deglitch_fail.png b/img/deglitch_fail.png deleted file mode 100644 index ed08641..0000000 Binary files a/img/deglitch_fail.png and /dev/null differ diff --git a/src/sim.py b/src/sim.py index e5c804f..7cfb32b 100644 --- a/src/sim.py +++ b/src/sim.py @@ -6,7 +6,8 @@ from wave_gen import square @njit def simulation_jit( time, - gtx, + gtx_freq, + gtx_jitter, helper_pll, main_pll, h_KP, @@ -16,13 +17,13 @@ def simulation_jit( m_KI, m_KD, dcxo_freq, - dcxo_jitter_SD, + h_jitter, + m_jitter, base_adpll, N, adpll_write_period, blind_period, start_up_delay, - cycle_slip_comp=True, helper_init_freq=0 ): @@ -30,6 +31,7 @@ def simulation_jit( main = np.zeros(arr_len, dtype=np.int8) helper = np.zeros(arr_len, dtype=np.int8) + gtx = np.zeros(arr_len, dtype=np.int8) gtx_tag = 0 gtx_ready = 0 @@ -40,7 +42,6 @@ def simulation_jit( main_ready = 0 main_FF = 0 main_beating = np.zeros(arr_len, dtype=np.int8) - collector_r = 0 phase_err_arr = np.zeros(arr_len, dtype=np.int16) period_err_arr = np.zeros(arr_len, dtype=np.int16) @@ -50,15 +51,20 @@ def simulation_jit( helperfreq = np.zeros(arr_len, dtype=np.int32) mainfreq = np.zeros(arr_len, dtype=np.int32) - # initial condition - main_init_offset = (np.random.uniform(0, 360)) / (360 * dcxo_freq) - helper_init_offset = (np.random.uniform(0, 360)) / (360 * dcxo_freq) + phase_collector_r = 0 + colr_gtx_tag = colr_main_tag = 0 + phase_collector_state = 0 - # intermediate values - helper_jitter = np.random.normal(0, dcxo_jitter_SD) - main_jitter = np.random.normal(0, dcxo_jitter_SD) - helper_cycle_num = 0 - main_cycle_num = 0 + period_collector_r = 0 + colr_last_gtx_tag = 0 + beating_period = 0 + period_collector_state = 0 + + # initial condition + timestep = time[1] - time[0] + gtx_phase = 0 + h_phase = np.random.uniform(0, 360) + m_phase = np.random.uniform(0, 360) if main_pll and not helper_pll: helper_freq = helper_init_freq @@ -66,12 +72,10 @@ def simulation_jit( helper_freq = dcxo_freq * (1 + base_adpll * 0.0001164 / 1_000_000) * ((N-1) / N) main_freq = dcxo_freq * (1 + base_adpll * 0.0001164 / 1_000_000) - FSM_state = 0 - coll_gtx_tag = coll_main_tag = 0 last_gtx_tag = last_gtx_FF = last_gtx_beat = 0 last_main_tag = last_main_FF = last_main_beat = 0 - last_helper = last_main = 0 + last_helper = 0 counter = 0 gtx_blind_counter = 0 @@ -84,38 +88,32 @@ def simulation_jit( wait_main = True # firmware values - FW_gtx_tag = FW_last_gtx_tag = 0 + FW_gtx_tag = 0 FW_main_tag = 0 period_err = last_period_err = 0 h_prop = h_integrator = h_derivative = 0 h_adpll = base_adpll - h_i2c_active = False + period_colr_arm = h_i2c_active = False phase_err = last_phase_err = 0 m_prop = m_integrator = m_derivative = 0 m_adpll = base_adpll - m_i2c_active = False + phase_colr_arm = m_i2c_active = False - adpll_active = False adpll_max = 8161512 def clip(n, minn, maxn): return max(min(maxn, n), minn) for i, t in enumerate(time): - helper[i], helper_cycle_num, helper_jitter = square( - t + helper_init_offset, helper_freq, dcxo_jitter_SD, helper_jitter, helper_cycle_num) - main[i], main_cycle_num, main_jitter = square( - t + main_init_offset, main_freq, dcxo_jitter_SD, main_jitter, main_cycle_num) + h_phase += 360 * helper_freq * (timestep + h_jitter[i]) + helper[i] = square(h_phase) - # continuous glitchless output, assume very small frequency change - if h_i2c_active and helper[i]: - helper_freq = dcxo_freq * (1 + h_adpll * 0.0001164 / 1_000_000) * ((N-1) / N) - h_i2c_active = False + m_phase += 360 * main_freq * (timestep + m_jitter[i]) + main[i] = square(m_phase) - if m_i2c_active and main[i]: - main_freq = dcxo_freq * (1 + m_adpll * 0.0001164 / 1_000_000) - m_i2c_active = False + gtx_phase += 360 * gtx_freq * (timestep + gtx_jitter[i]) + gtx[i] = square(gtx_phase) if not last_helper and helper[i]: @@ -130,12 +128,15 @@ def simulation_jit( main_blind_counter, main_blinded, blind_period, last_main_beat, last_main_tag, counter) - collector_r, wait_gtx, wait_main, coll_gtx_tag, coll_main_tag, FSM_state = Collector_FSM(gtx_ready, main_ready, gtx_tag, main_tag, - wait_gtx, wait_main, coll_gtx_tag, coll_main_tag, FSM_state) + phase_collector_r, wait_gtx, wait_main, colr_gtx_tag, colr_main_tag, phase_collector_state = phase_collector_FSM(gtx_ready, main_ready, gtx_tag, main_tag, + wait_gtx, wait_main, colr_gtx_tag, colr_main_tag, phase_collector_state) - if collector_r: - FW_gtx_tag = coll_gtx_tag - FW_main_tag = coll_main_tag + if phase_collector_r: + FW_gtx_tag = colr_gtx_tag + FW_main_tag = colr_main_tag + + period_collector_r, colr_last_gtx_tag, beating_period, period_collector_state = period_collector_FSM(gtx_ready, gtx_tag, + colr_last_gtx_tag, period_collector_state) counter += 1 @@ -158,20 +159,11 @@ def simulation_jit( if i > start_up_delay: # Firmware filters - if adpll_active and collector_r: - if cycle_slip_comp: - period = FW_gtx_tag - FW_last_gtx_tag - if period > 3 * N/2: - period = period - N - period_err = N - period - else: - period_err = (N - (FW_gtx_tag - FW_last_gtx_tag)) - tag_diff = ((FW_main_tag - FW_gtx_tag) % N) - if tag_diff > N/2: - phase_err = tag_diff - N - else: - phase_err = tag_diff + if period_colr_arm and period_collector_r: + + period_colr_arm = False + period_err = N - beating_period if helper_pll: h_prop = period_err * h_KP @@ -180,8 +172,18 @@ def simulation_jit( h_adpll = clip(int(base_adpll + h_prop + h_integrator + h_derivative), -adpll_max, adpll_max) last_period_err = period_err + helper_freq = dcxo_freq * (1 + h_adpll * 0.0001164 / 1_000_000) * ((N-1) / N) h_i2c_active = True + if phase_colr_arm and phase_collector_r: + + phase_colr_arm = False + tag_diff = ((FW_main_tag - FW_gtx_tag) % N) + if tag_diff > N/2: + phase_err = tag_diff - N + else: + phase_err = tag_diff + if main_pll: m_prop = phase_err * m_KP m_integrator += phase_err * m_KI @@ -189,16 +191,13 @@ def simulation_jit( m_adpll = clip(int(base_adpll + m_prop + m_integrator + m_derivative), -adpll_max, adpll_max) last_phase_err = phase_err + main_freq = dcxo_freq * (1 + m_adpll * 0.0001164 / 1_000_000) m_i2c_active = True - adpll_active = False - if i % adpll_write_period == 0: - adpll_active = True - FW_last_gtx_tag = FW_gtx_tag + period_colr_arm = phase_colr_arm = True last_helper = helper[i] - last_main = main[i] # Data period_err_arr[i] = period_err @@ -208,7 +207,7 @@ def simulation_jit( helperfreq[i] = helper_freq mainfreq[i] = main_freq - return period_err_arr, phase_err_arr, helper_adpll_arr, main_adpll_arr, gtx_beating, main_beating, helper, main, helperfreq, mainfreq + return period_err_arr, phase_err_arr, helper_adpll_arr, main_adpll_arr, gtx_beating, main_beating, gtx, helper, main, helperfreq, mainfreq @njit @@ -237,37 +236,55 @@ def Deglitcher(beating, t_out, t_ready, blind_counter, blinded, blind_period, la @njit -def Collector_FSM(g_tag_r, m_tag_r, gtx_tag, main_tag, wait_gtx, wait_main, coll_gtx_tag, coll_main_tag, FSM_state): +def phase_collector_FSM(g_tag_r, m_tag_r, gtx_tag, main_tag, wait_gtx, wait_main, colr_gtx_tag, colr_main_tag, FSM_state): collector_r = 0 match FSM_state: case 0: # IDEL if g_tag_r and m_tag_r: - coll_gtx_tag = gtx_tag - coll_main_tag = main_tag + colr_gtx_tag = gtx_tag + colr_main_tag = main_tag FSM_state = 3 # OUTPUT elif g_tag_r: - coll_gtx_tag = gtx_tag + colr_gtx_tag = gtx_tag wait_main = True FSM_state = 2 # WAITMAIN elif m_tag_r: - coll_main_tag = main_tag + colr_main_tag = main_tag wait_gtx = True FSM_state = 1 # WAITGTX case 1: # WAITGTX if g_tag_r: - coll_gtx_tag = gtx_tag + colr_gtx_tag = gtx_tag FSM_state = 3 # OUTPUT case 2: # WAITMAIN if m_tag_r: - coll_main_tag = main_tag + colr_main_tag = main_tag FSM_state = 3 # OUTPUT case 3: # OUTPUT wait_gtx = wait_main = False collector_r = 1 FSM_state = 0 - return collector_r, wait_gtx, wait_main, coll_gtx_tag, coll_main_tag, FSM_state + return collector_r, wait_gtx, wait_main, colr_gtx_tag, colr_main_tag, FSM_state + + +@njit +def period_collector_FSM(g_tag_r, gtx_tag, colr_last_gtx_tag, FSM_state): + collector_r = 0 + beating_period = 0 + match FSM_state: + case 0: # IDEL + if g_tag_r: + colr_last_gtx_tag = gtx_tag + FSM_state = 1 + case 1: # OUTPUT + if g_tag_r: + beating_period = gtx_tag - colr_last_gtx_tag + collector_r = 1 + FSM_state = 0 # IDEL + + return collector_r, colr_last_gtx_tag, beating_period, FSM_state diff --git a/src/wave_gen.py b/src/wave_gen.py index ece4fa2..f67ed6c 100644 --- a/src/wave_gen.py +++ b/src/wave_gen.py @@ -1,69 +1,28 @@ import numpy as np +import numba as nb from numba import njit -@njit -def square(time, freq, jitter_SD, jitter, cycle_num): - """ - - A scipy like square wave with jitter - Parameters - ---------- - time : timestamp, in seconds - - freq : frequency in Hz - - jitter_SD : standard deviation of jitter, in seconds - - jitter : gaussian noise with `jitter_SD` as its standard deviation - - cycle_num : time // period = cycle_num - ---------- - - """ - - period = 1/(freq) - quarter = (period / 4) - out = 0 - - # A T/4 shift is applied to the following to match scipy square fn - # ┌──────┐ - # ───┘ └──── - # T/4 3T/4 - nth_cycle, t = np.divmod(time + period / 4, period) - if t >= (quarter + jitter) and t <= (3*quarter + jitter): - out = 1 - else: - out = 0 - - # update jitter every cycle - if nth_cycle != cycle_num: - jitter = np.random.normal(0, jitter_SD) - cycle_num = nth_cycle - - return out, cycle_num, jitter +def white_noise(low, high, size, seed=None): + return np.random.default_rng(seed).uniform(low, high, size) -@njit -def square_arr(time, freq, jitter_SD): - """ - - A scipy like square wave with jitter - Parameters - ---------- - time : numpy array - - freq : frequency in Hz - - jitter_SD : standard deviation of jitter, in seconds - ---------- - - """ - wave = np.zeros(len(time)) - jitter = np.random.normal(0, jitter_SD) - cycle_num = 0 - - for i, t in enumerate(time): - wave[i], cycle_num, jitter = square(t, freq, jitter_SD, jitter, cycle_num) +@njit(fastmath=True) +def square_with_jitter(time, freq, jitter): + n = len(time) + wave = np.empty(n) + timestep = time[1] - time[0] + phase = 0. + for i in range(n): + phase += 360 * freq * (timestep + jitter[i]) + wave[i] = square(phase) return wave + + +@njit +def square(x): + if np.mod(x, 360) < 180: + return 0 + else: + return 1 diff --git a/src/wrpll.py b/src/wrpll.py index 1ff83d4..97d95b4 100644 --- a/src/wrpll.py +++ b/src/wrpll.py @@ -1,5 +1,5 @@ import numpy as np -from wave_gen import square_arr +from wave_gen import white_noise from sim import simulation_jit @@ -12,16 +12,16 @@ class WRPLL_simulator(): helper_filter, main_filter, gtx_freq, - gtx_jitter_SD, + gtx_jitter, dcxo_freq, - dcxo_jitter_SD, + dcxo_jitter, freq_acquisition_SD, N, adpll_write_period, blind_period, start_up_delay, - cycle_slip_comp, - helper_init_freq=None + helper_init_freq=None, + seed=None ): self.time = time @@ -34,13 +34,15 @@ class WRPLL_simulator(): self.m_KD = main_filter["KD"] # init condition + self.gtx_freq = gtx_freq + self.gtx_jitter = white_noise(-gtx_jitter, gtx_jitter, len(time), seed) + self.dcxo_freq = dcxo_freq - self.dcxo_jitter_SD = dcxo_jitter_SD + self.h_jitter = white_noise(-dcxo_jitter, dcxo_jitter, len(time), seed) + self.m_jitter = white_noise(-dcxo_jitter, dcxo_jitter, len(time), seed) self.N = N self.helper_init_freq = helper_init_freq - self.gtx = square_arr(time, gtx_freq, gtx_jitter_SD) - # freq_acquisition() error freq_diff = gtx_freq - dcxo_freq + np.random.normal(0, freq_acquisition_SD) self.base_adpll = int(freq_diff * (1 / dcxo_freq) * (1e6 / 0.0001164)) @@ -49,7 +51,6 @@ class WRPLL_simulator(): self.adpll_write_period = adpll_write_period self.blind_period = blind_period self.start_up_delay = start_up_delay - self.cycle_slip_comp = cycle_slip_comp if type(self.sim_mode) is not str: raise ValueError(f"pll_type {type(self.sim_mode)} is not a string") @@ -68,9 +69,10 @@ class WRPLL_simulator(): def run(self): print("running simulation...") - self.period_err, self.phase_err, self.helper_adpll, self.main_adpll, self.gtx_beating, self.main_beating, self.helper, self.main, self.helperfreq, self.mainfreq = simulation_jit( + self.period_err, self.phase_err, self.helper_adpll, self.main_adpll, self.gtx_beating, self.main_beating, self.gtx, self.helper, self.main, self.helperfreq, self.mainfreq = simulation_jit( self.time, - self.gtx, + self.gtx_freq, + self.gtx_jitter, self.helper_pll, self.main_pll, self.h_KP, @@ -80,13 +82,13 @@ class WRPLL_simulator(): self.m_KI, self.m_KD, self.dcxo_freq, - self.dcxo_jitter_SD, + self.h_jitter, + self.m_jitter, self.base_adpll, self.N, self.adpll_write_period, self.blind_period, self.start_up_delay, - self.cycle_slip_comp, self.helper_init_freq ) print("Done!")