WRPLL time domain simulation
wave_gen: add square wave with jitter generator sim: add WRPLL gateware and firmware simulation sim: add PID for main and helper PLL sim: add options to choose PLL modes sim: add cycle slip compensation for helper PLL sim: optimize execution time with numba jit
This commit is contained in:
parent
72f47668c9
commit
cbc9772efa
|
@ -0,0 +1,274 @@
|
|||
import numpy as np
|
||||
from numba import jit
|
||||
from wave_gen import square
|
||||
|
||||
|
||||
@jit(nopython=True)
|
||||
def simulation_jit(
|
||||
time,
|
||||
gtx,
|
||||
helper_pll,
|
||||
main_pll,
|
||||
h_KP,
|
||||
h_KI,
|
||||
h_KD,
|
||||
m_KP,
|
||||
m_KI,
|
||||
m_KD,
|
||||
dcxo_freq,
|
||||
dcxo_jitter_SD,
|
||||
base_adpll,
|
||||
N,
|
||||
adpll_write_period,
|
||||
blind_period,
|
||||
start_up_delay,
|
||||
cycle_slip_comp=True,
|
||||
helper_init_freq=0
|
||||
):
|
||||
|
||||
arr_len = len(time)
|
||||
|
||||
main = np.zeros(arr_len, dtype=np.int8)
|
||||
helper = np.zeros(arr_len, dtype=np.int8)
|
||||
|
||||
gtx_tag = 0
|
||||
gtx_ready = 0
|
||||
gtx_FF = 0
|
||||
gtx_beating = np.zeros(arr_len, dtype=np.int8)
|
||||
|
||||
main_tag = 0
|
||||
main_ready = 0
|
||||
main_FF = 0
|
||||
main_beating = np.zeros(arr_len, dtype=np.int8)
|
||||
collector_r = 0
|
||||
|
||||
phase_err_arr = np.zeros(arr_len, dtype=np.int16)
|
||||
period_err_arr = np.zeros(arr_len, dtype=np.int16)
|
||||
helper_adpll_arr = np.zeros(arr_len, dtype=np.int32)
|
||||
main_adpll_arr = np.zeros(arr_len, dtype=np.int32)
|
||||
|
||||
helperfreq = np.zeros(arr_len, dtype=np.int32)
|
||||
mainfreq = np.zeros(arr_len, dtype=np.int32)
|
||||
|
||||
# initial condition
|
||||
main_init_offset = (np.random.uniform(0, 360)) / (360 * dcxo_freq)
|
||||
helper_init_offset = (np.random.uniform(0, 360)) / (360 * dcxo_freq)
|
||||
|
||||
# intermediate values
|
||||
helper_jitter = np.random.normal(0, dcxo_jitter_SD)
|
||||
main_jitter = np.random.normal(0, dcxo_jitter_SD)
|
||||
helper_cycle_num = 0
|
||||
main_cycle_num = 0
|
||||
|
||||
if main_pll and not helper_pll:
|
||||
helper_freq = helper_init_freq
|
||||
else:
|
||||
helper_freq = dcxo_freq * (1 + base_adpll * 0.0001164 / 1_000_000) * ((N-1) / N)
|
||||
main_freq = dcxo_freq * (1 + base_adpll * 0.0001164 / 1_000_000)
|
||||
|
||||
FSM_state = 0
|
||||
coll_gtx_tag = coll_main_tag = 0
|
||||
last_gtx_tag = last_gtx_FF = last_gtx_beat = 0
|
||||
last_main_tag = last_main_FF = last_main_beat = 0
|
||||
|
||||
last_helper = last_main = 0
|
||||
|
||||
counter = 0
|
||||
gtx_blind_counter = 0
|
||||
gtx_blinded = False
|
||||
|
||||
main_blind_counter = 0
|
||||
main_blinded = False
|
||||
|
||||
wait_gtx = True
|
||||
wait_main = True
|
||||
|
||||
# firmware values
|
||||
FW_gtx_tag = FW_last_gtx_tag = 0
|
||||
FW_main_tag = 0
|
||||
|
||||
period_err = last_period_err = 0
|
||||
h_prop = h_integrator = h_derivative = 0
|
||||
h_adpll = base_adpll
|
||||
h_i2c_active = False
|
||||
|
||||
phase_err = last_phase_err = 0
|
||||
m_prop = m_integrator = m_derivative = 0
|
||||
m_adpll = base_adpll
|
||||
m_i2c_active = False
|
||||
|
||||
adpll_active = False
|
||||
adpll_max = 8161512
|
||||
def clip(n, minn, maxn): return max(min(maxn, n), minn)
|
||||
|
||||
for i, t in enumerate(time):
|
||||
|
||||
helper[i], helper_cycle_num, helper_jitter = square(
|
||||
t + helper_init_offset, helper_freq, dcxo_jitter_SD, helper_jitter, helper_cycle_num)
|
||||
main[i], main_cycle_num, main_jitter = square(
|
||||
t + main_init_offset, main_freq, dcxo_jitter_SD, main_jitter, main_cycle_num)
|
||||
|
||||
# continuous glitchless output, assume very small frequency change
|
||||
if h_i2c_active and helper[i] == 1:
|
||||
helper_freq = dcxo_freq * (1 + h_adpll * 0.0001164 / 1_000_000) * ((N-1) / N)
|
||||
h_i2c_active = False
|
||||
|
||||
if m_i2c_active and main[i] == 1:
|
||||
main_freq = dcxo_freq * (1 + m_adpll * 0.0001164 / 1_000_000)
|
||||
m_i2c_active = False
|
||||
|
||||
if last_helper == 0 and helper[i] == 1:
|
||||
|
||||
gtx_FF, gtx_beating[i] = DDMTD(gtx[i], last_gtx_FF)
|
||||
main_FF, main_beating[i] = DDMTD(main[i], last_main_FF)
|
||||
|
||||
gtx_tag, gtx_ready, gtx_blind_counter, gtx_blinded = Deglitcher(gtx_beating[i], gtx_tag, gtx_ready,
|
||||
gtx_blind_counter, gtx_blinded, blind_period,
|
||||
last_gtx_beat, last_gtx_tag, counter)
|
||||
|
||||
main_tag, main_ready, main_blind_counter, main_blinded = Deglitcher(main_beating[i], main_tag, main_ready,
|
||||
main_blind_counter, main_blinded, blind_period,
|
||||
last_main_beat, last_main_tag, counter)
|
||||
|
||||
collector_r, wait_gtx, wait_main, coll_gtx_tag, coll_main_tag, FSM_state = Collector_FSM(gtx_ready, main_ready, gtx_tag, main_tag,
|
||||
wait_gtx, wait_main, coll_gtx_tag, coll_main_tag, FSM_state)
|
||||
|
||||
if collector_r == 1:
|
||||
FW_gtx_tag = coll_gtx_tag
|
||||
FW_main_tag = coll_main_tag
|
||||
|
||||
counter += 1
|
||||
|
||||
last_gtx_beat = gtx_beating[i]
|
||||
last_gtx_FF = gtx_FF
|
||||
last_gtx_tag = gtx_tag
|
||||
last_main_beat = main_beating[i]
|
||||
last_main_FF = main_FF
|
||||
last_main_tag = main_tag
|
||||
|
||||
else:
|
||||
gtx_beating[i] = last_gtx_beat
|
||||
gtx_FF = last_gtx_FF
|
||||
gtx_tag = last_gtx_tag
|
||||
|
||||
main_beating[i] = last_main_beat
|
||||
main_FF = last_main_FF
|
||||
main_tag = last_main_tag
|
||||
|
||||
if i > start_up_delay:
|
||||
|
||||
# Firmware filters
|
||||
if adpll_active and collector_r == 1:
|
||||
if cycle_slip_comp:
|
||||
period = FW_gtx_tag - FW_last_gtx_tag
|
||||
if period > 3 * N/2:
|
||||
period = period - N
|
||||
period_err = N - period
|
||||
else:
|
||||
period_err = (N - (FW_gtx_tag - FW_last_gtx_tag))
|
||||
|
||||
tag_diff = ((FW_main_tag - FW_gtx_tag) % N)
|
||||
if tag_diff > N/2:
|
||||
phase_err = tag_diff - N
|
||||
else:
|
||||
phase_err = tag_diff
|
||||
|
||||
if helper_pll:
|
||||
h_prop = period_err * h_KP
|
||||
h_integrator += period_err * h_KI
|
||||
h_derivative = (period_err - last_period_err) * h_KD
|
||||
# h_derivative = 0
|
||||
|
||||
h_adpll = clip(int(base_adpll + h_prop + h_integrator + h_derivative), -adpll_max, adpll_max)
|
||||
last_period_err = period_err
|
||||
h_i2c_active = True
|
||||
|
||||
if main_pll:
|
||||
m_prop = phase_err * m_KP
|
||||
m_integrator += phase_err * m_KI
|
||||
m_derivative = (phase_err - last_phase_err) * m_KD
|
||||
|
||||
m_adpll = clip(int(base_adpll + m_prop + m_integrator + m_derivative), -adpll_max, adpll_max)
|
||||
last_phase_err = phase_err
|
||||
m_i2c_active = True
|
||||
|
||||
adpll_active = False
|
||||
|
||||
if i % adpll_write_period == 0:
|
||||
adpll_active = True
|
||||
FW_last_gtx_tag = FW_gtx_tag
|
||||
|
||||
last_helper = helper[i]
|
||||
last_main = main[i]
|
||||
|
||||
# Data
|
||||
period_err_arr[i] = period_err
|
||||
phase_err_arr[i] = phase_err
|
||||
helper_adpll_arr[i] = h_adpll
|
||||
main_adpll_arr[i] = m_adpll
|
||||
helperfreq[i] = helper_freq
|
||||
mainfreq[i] = main_freq
|
||||
|
||||
return period_err_arr, phase_err_arr, helper_adpll_arr, main_adpll_arr, gtx_beating, main_beating, helper, main, helperfreq, mainfreq
|
||||
|
||||
|
||||
@jit(nopython=True)
|
||||
def DDMTD(sig_in, last_FF):
|
||||
return sig_in, last_FF
|
||||
|
||||
|
||||
@jit(nopython=True)
|
||||
def Deglitcher(beating, t_out, t_ready, blind_counter, blinded, blind_period, last_beat, last_tag, counter):
|
||||
|
||||
if blind_counter == 0 and beating == 1 and last_beat == 0: # rising
|
||||
t_out = counter
|
||||
t_ready = 1
|
||||
blinded = True
|
||||
else:
|
||||
t_out = last_tag
|
||||
t_ready = 0
|
||||
|
||||
if beating == 1:
|
||||
blind_counter = blind_period - 1
|
||||
|
||||
if blind_counter != 0:
|
||||
blind_counter -= 1
|
||||
|
||||
return t_out, t_ready, blind_counter, blinded
|
||||
|
||||
|
||||
@jit(nopython=True)
|
||||
def Collector_FSM(g_tag_r, m_tag_r, gtx_tag, main_tag, wait_gtx, wait_main, coll_gtx_tag, coll_main_tag, FSM_state):
|
||||
|
||||
collector_r = 0
|
||||
|
||||
match FSM_state:
|
||||
case 0: # IDEL
|
||||
if g_tag_r == 1 and m_tag_r == 1:
|
||||
coll_gtx_tag = gtx_tag
|
||||
coll_main_tag = main_tag
|
||||
FSM_state = 3 # OUTPUT
|
||||
|
||||
elif g_tag_r == 1:
|
||||
coll_gtx_tag = gtx_tag
|
||||
wait_main = True
|
||||
FSM_state = 2 # WAITMAIN
|
||||
|
||||
elif m_tag_r == 1:
|
||||
coll_main_tag = main_tag
|
||||
wait_gtx = True
|
||||
FSM_state = 1 # WAITGTX
|
||||
case 1: # WAITGTX
|
||||
if g_tag_r == 1:
|
||||
coll_gtx_tag = gtx_tag
|
||||
FSM_state = 3 # OUTPUT
|
||||
case 2: # WAITMAIN
|
||||
if m_tag_r == 1:
|
||||
coll_main_tag = main_tag
|
||||
FSM_state = 3 # OUTPUT
|
||||
case 3: # OUTPUT
|
||||
wait_gtx = wait_main = False
|
||||
collector_r = 1
|
||||
FSM_state = 0
|
||||
|
||||
return collector_r, wait_gtx, wait_main, coll_gtx_tag, coll_main_tag, FSM_state
|
|
@ -0,0 +1,69 @@
|
|||
import numpy as np
|
||||
from numba import jit
|
||||
|
||||
|
||||
@jit(nopython=True)
|
||||
def square(time, freq, jitter_SD, jitter, cycle_num):
|
||||
"""
|
||||
|
||||
A scipy like square wave with jitter
|
||||
Parameters
|
||||
----------
|
||||
time : timestamp, in seconds
|
||||
|
||||
freq : frequency in Hz
|
||||
|
||||
jitter_SD : standard deviation of jitter, in seconds
|
||||
|
||||
jitter : gaussian noise with `jitter_SD` as its standard deviation
|
||||
|
||||
cycle_num : time // period = cycle_num
|
||||
----------
|
||||
|
||||
"""
|
||||
|
||||
period = 1/(freq)
|
||||
quarter = (period / 4)
|
||||
out = 0
|
||||
|
||||
# A T/4 shift is applied to the following to match scipy square fn
|
||||
# ┌──────┐
|
||||
# ───┘ └────
|
||||
# T/4 3T/4
|
||||
nth_cycle, t = np.divmod(time + period / 4, period)
|
||||
if t >= (quarter + jitter) and t <= (3*quarter + jitter):
|
||||
out = 1
|
||||
else:
|
||||
out = 0
|
||||
|
||||
# update jitter every cycle
|
||||
if nth_cycle != cycle_num:
|
||||
jitter = np.random.normal(0, jitter_SD)
|
||||
cycle_num = nth_cycle
|
||||
|
||||
return out, cycle_num, jitter
|
||||
|
||||
|
||||
@jit(nopython=True)
|
||||
def square_arr(time, freq, jitter_SD):
|
||||
"""
|
||||
|
||||
A scipy like square wave with jitter
|
||||
Parameters
|
||||
----------
|
||||
time : numpy array
|
||||
|
||||
freq : frequency in Hz
|
||||
|
||||
jitter_SD : standard deviation of jitter, in seconds
|
||||
----------
|
||||
|
||||
"""
|
||||
wave = np.zeros(len(time))
|
||||
jitter = np.random.normal(0, jitter_SD)
|
||||
cycle_num = 0
|
||||
|
||||
for i, t in enumerate(time):
|
||||
wave[i], cycle_num, jitter = square(t, freq, jitter_SD, jitter, cycle_num)
|
||||
|
||||
return wave
|
Loading…
Reference in New Issue