diff --git a/gateware/logic/pid.py b/gateware/logic/pid.py
index 4320f94..e737577 100644
--- a/gateware/logic/pid.py
+++ b/gateware/logic/pid.py
@@ -56,10 +56,13 @@ class PID(Module, AutoCSR):
         self.comb += [kp_signed.eq(self.kp.storage)]
 
         kp_mult = Signal((self.width + self.coeff_width, True))
-        self.comb += [kp_mult.eq(self.error * kp_signed)]
+        kp_mult_reg = Signal((self.width + self.coeff_width, True))
+        self.sync += kp_mult.eq(kp_mult_reg >> (self.coeff_width - 2))
+
+        self.comb += [kp_mult_reg.eq(self.error * kp_signed)]
 
         self.output_p = Signal((self.width, True))
-        self.comb += [self.output_p.eq(kp_mult >> (self.coeff_width - 2))]
+        self.comb += [self.output_p.eq(kp_mult)]
 
         self.kp_mult = kp_mult
 
@@ -71,8 +74,10 @@ class PID(Module, AutoCSR):
         self.comb += [ki_signed.eq(self.ki.storage)]
 
         self.ki_mult = Signal((1 + self.width + self.coeff_width, True))
+        self.ki_mult_reg = Signal((1 + self.width + self.coeff_width, True))
+        self.sync += self.ki_mult.eq(self.ki_mult_reg)
+        self.comb += self.ki_mult_reg.eq((self.error * ki_signed) >> 4)
 
-        self.comb += [self.ki_mult.eq((self.error * ki_signed) >> 4)]
 
         int_reg_width = self.width + self.coeff_width + 4
         extra_width = int_reg_width - self.width
@@ -110,15 +115,17 @@ class PID(Module, AutoCSR):
         self.kd = CSRStorage(self.coeff_width)
         kd_signed = Signal((self.coeff_width, True))
         kd_mult = Signal((mult_width, True))
+        kd_mult_reg = Signal((mult_width, True))
+        self.sync += kd_mult.eq(kd_mult_reg)
 
-        self.comb += [kd_signed.eq(self.kd.storage), kd_mult.eq(self.error * kd_signed)]
+        self.comb += [kd_signed.eq(self.kd.storage), kd_mult_reg.eq(self.error * kd_signed >> (self.coeff_width - self.d_shift))]
 
         kd_reg = Signal((out_width, True))
         kd_reg_r = Signal((out_width, True))
 
         self.output_d = Signal((out_width, True))
         self.sync += [
-            kd_reg.eq(kd_mult >> (self.coeff_width - self.d_shift)),
+            kd_reg.eq(kd_mult),
             kd_reg_r.eq(kd_reg),
             self.output_d.eq(kd_reg - kd_reg_r),
         ]
@@ -143,4 +150,10 @@ class PID(Module, AutoCSR):
 
         # sync is required here, otherwise we get artifacts when one of the
         # signals changes sign
-        self.sync += [self.pid_sum.eq(self.output_p + self.int_out + self.output_d)]
+        self.sync += [
+            If(
+                self.running,
+                self.pid_sum.eq(self.output_p + self.int_out + self.output_d),
+            )
+            .Else(self.pid_sum.eq(0))
+        ]