WRPLL time domain simulation

wave_gen: add square wave with jitter generator
sim: add WRPLL gateware and firmware simulation
sim: add PID for main and helper PLL
sim: add options to choose PLL modes
sim: add cycle slip compensation for helper PLL
sim: optimize execution time with numba jit
This commit is contained in:
morgan 2023-12-05 13:23:23 +08:00
parent 72f47668c9
commit cbc9772efa
2 changed files with 343 additions and 0 deletions

274
src/sim.py Normal file
View File

@ -0,0 +1,274 @@
import numpy as np
from numba import jit
from wave_gen import square
@jit(nopython=True)
def simulation_jit(
time,
gtx,
helper_pll,
main_pll,
h_KP,
h_KI,
h_KD,
m_KP,
m_KI,
m_KD,
dcxo_freq,
dcxo_jitter_SD,
base_adpll,
N,
adpll_write_period,
blind_period,
start_up_delay,
cycle_slip_comp=True,
helper_init_freq=0
):
arr_len = len(time)
main = np.zeros(arr_len, dtype=np.int8)
helper = np.zeros(arr_len, dtype=np.int8)
gtx_tag = 0
gtx_ready = 0
gtx_FF = 0
gtx_beating = np.zeros(arr_len, dtype=np.int8)
main_tag = 0
main_ready = 0
main_FF = 0
main_beating = np.zeros(arr_len, dtype=np.int8)
collector_r = 0
phase_err_arr = np.zeros(arr_len, dtype=np.int16)
period_err_arr = np.zeros(arr_len, dtype=np.int16)
helper_adpll_arr = np.zeros(arr_len, dtype=np.int32)
main_adpll_arr = np.zeros(arr_len, dtype=np.int32)
helperfreq = np.zeros(arr_len, dtype=np.int32)
mainfreq = np.zeros(arr_len, dtype=np.int32)
# initial condition
main_init_offset = (np.random.uniform(0, 360)) / (360 * dcxo_freq)
helper_init_offset = (np.random.uniform(0, 360)) / (360 * dcxo_freq)
# intermediate values
helper_jitter = np.random.normal(0, dcxo_jitter_SD)
main_jitter = np.random.normal(0, dcxo_jitter_SD)
helper_cycle_num = 0
main_cycle_num = 0
if main_pll and not helper_pll:
helper_freq = helper_init_freq
else:
helper_freq = dcxo_freq * (1 + base_adpll * 0.0001164 / 1_000_000) * ((N-1) / N)
main_freq = dcxo_freq * (1 + base_adpll * 0.0001164 / 1_000_000)
FSM_state = 0
coll_gtx_tag = coll_main_tag = 0
last_gtx_tag = last_gtx_FF = last_gtx_beat = 0
last_main_tag = last_main_FF = last_main_beat = 0
last_helper = last_main = 0
counter = 0
gtx_blind_counter = 0
gtx_blinded = False
main_blind_counter = 0
main_blinded = False
wait_gtx = True
wait_main = True
# firmware values
FW_gtx_tag = FW_last_gtx_tag = 0
FW_main_tag = 0
period_err = last_period_err = 0
h_prop = h_integrator = h_derivative = 0
h_adpll = base_adpll
h_i2c_active = False
phase_err = last_phase_err = 0
m_prop = m_integrator = m_derivative = 0
m_adpll = base_adpll
m_i2c_active = False
adpll_active = False
adpll_max = 8161512
def clip(n, minn, maxn): return max(min(maxn, n), minn)
for i, t in enumerate(time):
helper[i], helper_cycle_num, helper_jitter = square(
t + helper_init_offset, helper_freq, dcxo_jitter_SD, helper_jitter, helper_cycle_num)
main[i], main_cycle_num, main_jitter = square(
t + main_init_offset, main_freq, dcxo_jitter_SD, main_jitter, main_cycle_num)
# continuous glitchless output, assume very small frequency change
if h_i2c_active and helper[i] == 1:
helper_freq = dcxo_freq * (1 + h_adpll * 0.0001164 / 1_000_000) * ((N-1) / N)
h_i2c_active = False
if m_i2c_active and main[i] == 1:
main_freq = dcxo_freq * (1 + m_adpll * 0.0001164 / 1_000_000)
m_i2c_active = False
if last_helper == 0 and helper[i] == 1:
gtx_FF, gtx_beating[i] = DDMTD(gtx[i], last_gtx_FF)
main_FF, main_beating[i] = DDMTD(main[i], last_main_FF)
gtx_tag, gtx_ready, gtx_blind_counter, gtx_blinded = Deglitcher(gtx_beating[i], gtx_tag, gtx_ready,
gtx_blind_counter, gtx_blinded, blind_period,
last_gtx_beat, last_gtx_tag, counter)
main_tag, main_ready, main_blind_counter, main_blinded = Deglitcher(main_beating[i], main_tag, main_ready,
main_blind_counter, main_blinded, blind_period,
last_main_beat, last_main_tag, counter)
collector_r, wait_gtx, wait_main, coll_gtx_tag, coll_main_tag, FSM_state = Collector_FSM(gtx_ready, main_ready, gtx_tag, main_tag,
wait_gtx, wait_main, coll_gtx_tag, coll_main_tag, FSM_state)
if collector_r == 1:
FW_gtx_tag = coll_gtx_tag
FW_main_tag = coll_main_tag
counter += 1
last_gtx_beat = gtx_beating[i]
last_gtx_FF = gtx_FF
last_gtx_tag = gtx_tag
last_main_beat = main_beating[i]
last_main_FF = main_FF
last_main_tag = main_tag
else:
gtx_beating[i] = last_gtx_beat
gtx_FF = last_gtx_FF
gtx_tag = last_gtx_tag
main_beating[i] = last_main_beat
main_FF = last_main_FF
main_tag = last_main_tag
if i > start_up_delay:
# Firmware filters
if adpll_active and collector_r == 1:
if cycle_slip_comp:
period = FW_gtx_tag - FW_last_gtx_tag
if period > 3 * N/2:
period = period - N
period_err = N - period
else:
period_err = (N - (FW_gtx_tag - FW_last_gtx_tag))
tag_diff = ((FW_main_tag - FW_gtx_tag) % N)
if tag_diff > N/2:
phase_err = tag_diff - N
else:
phase_err = tag_diff
if helper_pll:
h_prop = period_err * h_KP
h_integrator += period_err * h_KI
h_derivative = (period_err - last_period_err) * h_KD
# h_derivative = 0
h_adpll = clip(int(base_adpll + h_prop + h_integrator + h_derivative), -adpll_max, adpll_max)
last_period_err = period_err
h_i2c_active = True
if main_pll:
m_prop = phase_err * m_KP
m_integrator += phase_err * m_KI
m_derivative = (phase_err - last_phase_err) * m_KD
m_adpll = clip(int(base_adpll + m_prop + m_integrator + m_derivative), -adpll_max, adpll_max)
last_phase_err = phase_err
m_i2c_active = True
adpll_active = False
if i % adpll_write_period == 0:
adpll_active = True
FW_last_gtx_tag = FW_gtx_tag
last_helper = helper[i]
last_main = main[i]
# Data
period_err_arr[i] = period_err
phase_err_arr[i] = phase_err
helper_adpll_arr[i] = h_adpll
main_adpll_arr[i] = m_adpll
helperfreq[i] = helper_freq
mainfreq[i] = main_freq
return period_err_arr, phase_err_arr, helper_adpll_arr, main_adpll_arr, gtx_beating, main_beating, helper, main, helperfreq, mainfreq
@jit(nopython=True)
def DDMTD(sig_in, last_FF):
return sig_in, last_FF
@jit(nopython=True)
def Deglitcher(beating, t_out, t_ready, blind_counter, blinded, blind_period, last_beat, last_tag, counter):
if blind_counter == 0 and beating == 1 and last_beat == 0: # rising
t_out = counter
t_ready = 1
blinded = True
else:
t_out = last_tag
t_ready = 0
if beating == 1:
blind_counter = blind_period - 1
if blind_counter != 0:
blind_counter -= 1
return t_out, t_ready, blind_counter, blinded
@jit(nopython=True)
def Collector_FSM(g_tag_r, m_tag_r, gtx_tag, main_tag, wait_gtx, wait_main, coll_gtx_tag, coll_main_tag, FSM_state):
collector_r = 0
match FSM_state:
case 0: # IDEL
if g_tag_r == 1 and m_tag_r == 1:
coll_gtx_tag = gtx_tag
coll_main_tag = main_tag
FSM_state = 3 # OUTPUT
elif g_tag_r == 1:
coll_gtx_tag = gtx_tag
wait_main = True
FSM_state = 2 # WAITMAIN
elif m_tag_r == 1:
coll_main_tag = main_tag
wait_gtx = True
FSM_state = 1 # WAITGTX
case 1: # WAITGTX
if g_tag_r == 1:
coll_gtx_tag = gtx_tag
FSM_state = 3 # OUTPUT
case 2: # WAITMAIN
if m_tag_r == 1:
coll_main_tag = main_tag
FSM_state = 3 # OUTPUT
case 3: # OUTPUT
wait_gtx = wait_main = False
collector_r = 1
FSM_state = 0
return collector_r, wait_gtx, wait_main, coll_gtx_tag, coll_main_tag, FSM_state

69
src/wave_gen.py Normal file
View File

@ -0,0 +1,69 @@
import numpy as np
from numba import jit
@jit(nopython=True)
def square(time, freq, jitter_SD, jitter, cycle_num):
"""
A scipy like square wave with jitter
Parameters
----------
time : timestamp, in seconds
freq : frequency in Hz
jitter_SD : standard deviation of jitter, in seconds
jitter : gaussian noise with `jitter_SD` as its standard deviation
cycle_num : time // period = cycle_num
----------
"""
period = 1/(freq)
quarter = (period / 4)
out = 0
# A T/4 shift is applied to the following to match scipy square fn
# ┌──────┐
# ───┘ └────
# T/4 3T/4
nth_cycle, t = np.divmod(time + period / 4, period)
if t >= (quarter + jitter) and t <= (3*quarter + jitter):
out = 1
else:
out = 0
# update jitter every cycle
if nth_cycle != cycle_num:
jitter = np.random.normal(0, jitter_SD)
cycle_num = nth_cycle
return out, cycle_num, jitter
@jit(nopython=True)
def square_arr(time, freq, jitter_SD):
"""
A scipy like square wave with jitter
Parameters
----------
time : numpy array
freq : frequency in Hz
jitter_SD : standard deviation of jitter, in seconds
----------
"""
wave = np.zeros(len(time))
jitter = np.random.normal(0, jitter_SD)
cycle_num = 0
for i, t in enumerate(time):
wave[i], cycle_num, jitter = square(t, freq, jitter_SD, jitter, cycle_num)
return wave