sim: improve code style and cleanup

sim & wave_gen: use njit instead of jit(nopython=True)
sim: removing extra == 1
wrapper: remove unused import
This commit is contained in:
morgan 2023-12-12 10:27:22 +08:00
parent 4b098c0a54
commit da5908e754
3 changed files with 21 additions and 23 deletions

View File

@ -1,9 +1,9 @@
import numpy as np
from numba import jit
from numba import njit
from wave_gen import square
@jit(nopython=True)
@njit
def simulation_jit(
time,
gtx,
@ -109,15 +109,15 @@ def simulation_jit(
t + main_init_offset, main_freq, dcxo_jitter_SD, main_jitter, main_cycle_num)
# continuous glitchless output, assume very small frequency change
if h_i2c_active and helper[i] == 1:
if h_i2c_active and helper[i]:
helper_freq = dcxo_freq * (1 + h_adpll * 0.0001164 / 1_000_000) * ((N-1) / N)
h_i2c_active = False
if m_i2c_active and main[i] == 1:
if m_i2c_active and main[i]:
main_freq = dcxo_freq * (1 + m_adpll * 0.0001164 / 1_000_000)
m_i2c_active = False
if last_helper == 0 and helper[i] == 1:
if not last_helper and helper[i]:
gtx_FF, gtx_beating[i] = DDMTD(gtx[i], last_gtx_FF)
main_FF, main_beating[i] = DDMTD(main[i], last_main_FF)
@ -133,7 +133,7 @@ def simulation_jit(
collector_r, wait_gtx, wait_main, coll_gtx_tag, coll_main_tag, FSM_state = Collector_FSM(gtx_ready, main_ready, gtx_tag, main_tag,
wait_gtx, wait_main, coll_gtx_tag, coll_main_tag, FSM_state)
if collector_r == 1:
if collector_r:
FW_gtx_tag = coll_gtx_tag
FW_main_tag = coll_main_tag
@ -158,7 +158,7 @@ def simulation_jit(
if i > start_up_delay:
# Firmware filters
if adpll_active and collector_r == 1:
if adpll_active and collector_r:
if cycle_slip_comp:
period = FW_gtx_tag - FW_last_gtx_tag
if period > 3 * N/2:
@ -172,12 +172,11 @@ def simulation_jit(
phase_err = tag_diff - N
else:
phase_err = tag_diff
if helper_pll:
h_prop = period_err * h_KP
h_integrator += period_err * h_KI
h_derivative = (period_err - last_period_err) * h_KD
# h_derivative = 0
h_adpll = clip(int(base_adpll + h_prop + h_integrator + h_derivative), -adpll_max, adpll_max)
last_period_err = period_err
@ -212,15 +211,15 @@ def simulation_jit(
return period_err_arr, phase_err_arr, helper_adpll_arr, main_adpll_arr, gtx_beating, main_beating, helper, main, helperfreq, mainfreq
@jit(nopython=True)
@njit
def DDMTD(sig_in, last_FF):
return sig_in, last_FF
@jit(nopython=True)
@njit
def Deglitcher(beating, t_out, t_ready, blind_counter, blinded, blind_period, last_beat, last_tag, counter):
if blind_counter == 0 and beating == 1 and last_beat == 0: # rising
if blind_counter == 0 and beating and not last_beat: # rising
t_out = counter
t_ready = 1
blinded = True
@ -228,7 +227,7 @@ def Deglitcher(beating, t_out, t_ready, blind_counter, blinded, blind_period, la
t_out = last_tag
t_ready = 0
if beating == 1:
if beating:
blind_counter = blind_period - 1
if blind_counter != 0:
@ -237,33 +236,33 @@ def Deglitcher(beating, t_out, t_ready, blind_counter, blinded, blind_period, la
return t_out, t_ready, blind_counter, blinded
@jit(nopython=True)
@njit
def Collector_FSM(g_tag_r, m_tag_r, gtx_tag, main_tag, wait_gtx, wait_main, coll_gtx_tag, coll_main_tag, FSM_state):
collector_r = 0
match FSM_state:
case 0: # IDEL
if g_tag_r == 1 and m_tag_r == 1:
if g_tag_r and m_tag_r:
coll_gtx_tag = gtx_tag
coll_main_tag = main_tag
FSM_state = 3 # OUTPUT
elif g_tag_r == 1:
elif g_tag_r:
coll_gtx_tag = gtx_tag
wait_main = True
FSM_state = 2 # WAITMAIN
elif m_tag_r == 1:
elif m_tag_r:
coll_main_tag = main_tag
wait_gtx = True
FSM_state = 1 # WAITGTX
case 1: # WAITGTX
if g_tag_r == 1:
if g_tag_r:
coll_gtx_tag = gtx_tag
FSM_state = 3 # OUTPUT
case 2: # WAITMAIN
if m_tag_r == 1:
if m_tag_r:
coll_main_tag = main_tag
FSM_state = 3 # OUTPUT
case 3: # OUTPUT

View File

@ -1,8 +1,8 @@
import numpy as np
from numba import jit
from numba import njit
@jit(nopython=True)
@njit
def square(time, freq, jitter_SD, jitter, cycle_num):
"""
@ -44,7 +44,7 @@ def square(time, freq, jitter_SD, jitter, cycle_num):
return out, cycle_num, jitter
@jit(nopython=True)
@njit
def square_arr(time, freq, jitter_SD):
"""

View File

@ -1,5 +1,4 @@
import numpy as np
from numba import jit
from wave_gen import square_arr
from sim import simulation_jit