DHCL + Progressive Alignment
参考之前的 lambda 退火,我们不再要求本地模型一步到位地对齐最终的ETF锚点,而是让对齐目标从一个容易的初始状态,平滑地、渐进地过渡到完美的最终状态。
结论
数据见 progressive-alignment.txt。
1. 数据分析
a. 模型一致性
alpha=0.05
:g_protos_std
从1.00606
(Round 0) 缓慢增长到1.00648
(Round 30)。alpha=0.1
:g_protos_std
从1.00604
(Round 0) 缓慢增长到1.00634
(Round 20)。alpha=0.3
:g_protos_std
从1.00598
(Round 0) 缓慢增长到1.00627
(Round 20)。alpha=0.5
:g_protos_std
从1.00594
(Round 0) 缓慢增长到1.00620
(Round 20)。
这个指标表明,客户端上传的本地原型仍然是高度发散的,彼此之间几乎没有对齐。
b. 准确率增长
观察任意一个alpha值的场景,例如 alpha=0.05
:
OursV8
的准确率在第10轮达到 48.47%,在第30轮达到 61.15%。- 回顾我们之前的成功实验,
OursV6
在第10轮已经达到 50.93%,在第30轮更是达到了 63.65%。
OursV8
的学习速度和最终性能都落后于之前的最佳版本。
c. 本地训练准确率
以 alpha=0.05
,Round 0
为例:
- Client 0: train accuracy: 0.269
- Client 1: train accuracy: 0.843
- Client 2: train accuracy: 0.847
- Client 3: train accuracy: 0.360
这个巨大的差异揭示了本地数据分布的极端不平衡。Client 1和2的本地数据可能很容易学习,因此它们的base_loss
会迅速下降,并产生一个极强的梯度,将learnable_proto
拉向这个“局部最优”的位置。
据推测,base_loss
和align_loss
的力量失衡可能造成该问题,故而接下来通过将 lambda_align_initial
调整至20继续实验。
数据见 progressive-alignment2.txt。
核心结论
1. 数据分析
a. 模型一致性 (g_protos_std
)
异构程度 (α) | 算法 | g_protos_std (第50轮) | 诊断 |
---|---|---|---|
0.05 | V6 (Best) | 0.70938 | 有效对齐 |
V8 (λ=20) | 1.00648 | 对齐完全失败 | |
0.1 | V7 (Best) | 0.65451 | 有效对齐 |
V8 (λ=20) | 1.00648 | 对齐完全失败 |
将lambda
从5.0
提升到20.0
,对最终的模型一致性没有任何改善。这证明了问题不在于lambda
的绝对值大小,而在于对齐力量施加的时机。
b. 最终准确率
alpha=0.1
:V8(λ=20)
获得了 76.88% 的准确率。这个结果几乎与V7
的76.86%完全相同。但这很可能是一个巧合,V7
的成功源于其高质量的ETF锚点引导,而V8
的这个结果更像是无对齐的随机表现。alpha=0.05
:V8(λ=20)
获得了 65.58% 的准确率。这显著低于V6
的68.17%,证明了在最困难的场景下,失效的对齐机制导致了严重的性能惩罚。
2. 原因分析
实验结果表明,本地数据驱动的base_loss
梯度在训练的最初几个批次中是如此之大,以至于它会瞬间将可学习原型learnable_proto
钉死在本地数据的最优位置上。
这个过程发生在第一个epoch之内,甚至在我们的lambda
退火机制有机会施加有意义的影响之前。一旦原型被“污染”,后续再强的对齐力量也难以将其拉回到一个理想的全局位置。
在深度学习训练的 t=0
时刻,模型参数是随机的,输出也是随机的。此时,CrossEntropyLoss会产生巨大的损失值和梯度。
在优化器的第一次更新中,base_loss
的梯度占据了绝对主导地位。在几个mini-batch之后,learnable_proto
就已经被塑造成了只服务于本地数据的形态。