lambda&alpha
之前对于 lambda 的分析不够深入。为了研究更好的 lambda 自适应方案,我们对其深入研究:
我们将lambda分别设置为 1.0 2.5 5.0 10.0 20.0 50.0,分别对CIFAR-10, α=0.05 , SVHN, α=0.05和CIFAR-10, α=0.3进行实验。结果如下:
表1:CIFAR-10, α = 0.05
| lambda_initial | Acc | inter_client_proto_std | g_protos_std |
|---|---|---|---|
| 1.0 | 58.89% | 0.0176 | 0.9871 |
| 2.5 | 57.44% | 0.0263 | 0.9585 |
| 5.0 | 58.77% | 0.0419 | 0.9136 |
| 10.0 | 59.38% | 0.0711 | 0.8736 |
| 20.0 | 59.68% | 0.1172 | 0.5968 |
| 50.0 | 59.39% | 0.1727 | 0.5028 |
表2:SVHN, α = 0.05
| lambda_initial | Acc | inter_client_proto_std | g_protos_std |
|---|---|---|---|
| 1.0 | 51.04% | 0.0695 | 0.8833 |
| 2.5 | 50.74% | 0.1355 | 0.7373 |
| 10.0 | 50.64% | 0.2770 | 0.3985 |
| 20.0 | 51.07% | 0.3050 | 0.3024 |
| 50.0 | 49.46% | 0.2979 | 0.2250 |
表3:CIFAR-10, α = 0.3
| lambda_initial | Acc | inter_client_proto_std | g_protos_std |
|---|---|---|---|
| 1.0 | 69.97% | 0.0125 | 0.9922 |
| 2.5 | 70.34% | 0.0146 | 0.9712 |
| 5.0 | 69.96% | 0.0197 | 0.9371 |
| 10.0 | 70.88% | 0.0297 | 0.8736 |
| 20.0 | 70.31% | 0.0461 | 0.7580 |
| 50.0 | 71.44% | 0.0743 | 0.5028 |
在所有三个场景中,我们都清晰地观察到了:随着lambda的增加,inter_client_proto_std 普遍上升,而g_protos_std普遍下降。这与2-2的观察相同。
无论是CIFAR-10还是SVHN,当alpha = 0.05时,性能都在1.0和20.0处较高。此时从经验来说,采用利用熵映射 lambda 至1.0 - 20.0是可行的。然而,随着 alpha 上升,最高性能达峰 lambda 会提升。我们证明了不存在一个全局最优的固定λ,然而现在的方法,只能说在对于像CIFAR10和SVHN这样的简单数据集是在低alpha下的一个凭经验得到的较优解。它无法保证在其他数据集或异构等级下也是最优的。这显然不能让人满意。
于是我们寻找类似的研究。本质上来说,本地对齐和全局对齐是两个相对不同的任务,总损失是这两个损失的加权和,模型的最终性能对权重lambda敏感。很明显手动地寻找最优权重,既困难又昂贵,在实践中几乎不可行。那么,我们自然应该寻找那些对于多任务权重的研究。于是我们找到了:
R. Cipolla, Y. Gal and A. Kendall, “Multi-task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics,” 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition, Salt Lake City, UT, USA, 2018, pp. 7482-7491, doi: 10.1109/CVPR.2018.00781. , https://ieeexplore.ieee.org/document/8578879
本文指出:在多任务学习中,模型的最终性能极其依赖于各个任务损失项之间的相对权重(即我们的λ)。然而,手动调整这些权重是一个困难且昂贵的过程,在实践中几乎是禁止性的。为了解决这个问题,论文提出了:基于贝叶斯建模中的同方差不确定性 (Homoscedastic Uncertainty) 来达到学习参数的目的。
显然,当一个模型同时学习多个任务时,我们应该更多地关注那些模型更确定的任务,而对那些模型更不确定的任务持保留态度。问题在于我们应该如何衡量这种确定度。目前我们根据输入数据的熵间接推断,但是这样并不直接,而且只是出于经验,没有严格证明。
该论文的关键在于,它将损失加权问题从一个超参调优问题,重构为一个概率建模问题。其核心思想在于:当一个模型同时学习多个任务时,我们应该更多地听取那些模型更确定的任务的意见,而对那些模型更不确定的任务持保留态度。
该理论将每个损失函数都建模为一个高斯似然。为了在不引入新超参数的情况下学习每个损失的相对权重,论文推导出以下这个损失函数:
$L_{total} = \frac{1}{2\sigma_{1}^2} L_{1} + \frac{1}{2\sigma_{2}^2} L_{2} + \log(\sigma_{1}) + \log(\sigma_{2})$
- L1 和 L2 是我们的两个损失项 (L_local 和 L_align)。
- σ1 和 σ2 是两个可学习的参数,它们代表了模型对这两个任务的同方差不确定性 (homoscedastic uncertainty),即任务本身固有的、与输入数据无关的噪声水平。
- 前半部分 (1/σ^2) * L 是在用不确定性的倒数来加权损失。σ越大(越不确定),权重就越小。
- 后半部分 log(σ) 是一个正则化项,它会惩罚过大的σ,防止模型通过简单地将所有σ都设为无穷大。
实际上,我们可以发现,这种方式与目前的参数方法是等价的:
比较 (1/2σ_local²) * L_local 和 (1/2σ_align²) * L_align,我们可以看到,有效的λ其实就是:
$\lambda_{eff} \approx \frac{\sigma_{local}^2}{\sigma_{align}^2}$
现在系统在训练过程中,会自己发现对于这个特定的客户端、在训练的这个特定阶段,“本地学习”和“全局对齐”哪个信号更可靠。如果L_local因为数据倾斜而剧烈波动(不确定性高),σ_local就会自动变大,从而动态地降低L_local的权重,让模型更多地听取稳定的L_align的引导。反之亦然。
于是,我们可以想到将这个方法引入。这是下一步的研究方向。
我们继续探索,结果发现这篇论文:
FedLPA: One-shot Federated Learning with Layer-Wise Posterior Aggregation , arXiv:2310.00339 [cs.LG] , https://doi.org/10.48550/arXiv.2310.00339
本文也涉及贝叶斯不确定性,然而与我们目前的思路不同。我们目前的思路在于:在客户端,如何更好地训练模型,而FedLPA专注于在服务器端,如何更好地融合模型。
本文指出:在One-shot FL中,简单地对客户端上传的模型参数进行平均(如FedAvg)是一种非常粗糙且无效的方式,尤其是在Non-IID数据下。
FedLPA的核心思想: 我们不应该平均参数,我们应该融合关于参数的概率信念。
- 它将每个客户端训练好的模型,不看作是一组固定的权重,而是看作是这个客户端基于其本地数据,对“理想模型”的一次带噪观测。
- 它认为,一个好的聚合,应该让那些对自己观测结果更“确定”的客户端,拥有更大的“话语权”。
实现机制 (The “How”):
- 客户端: 在完成本地训练后,客户端额外执行一步“自我反思”——使用拉普拉斯近似 (Laplace Approximation)。
- 拉普拉斯近似是什么? 它是一种快速估算模型权重后验分布的方法。简单来说,它通过计算损失函数在最优解附近的“曲率”(通过经验Fisher信息矩阵来近似),来判断模型对自己的权重有多“自信”。
- 曲率陡峭 (Sharp Curvature): 意味着损失函数对权重变化很敏感,模型非常确定当前权重是好的(低方差)。
- 曲率平坦 (Flat Curvature): 意味着权重稍有变化,损失也差不多,模型不确定权重的最优值(高方差)。
- 上传内容: 客户端不再只上传模型权重(后验分布的均值μ_k),还会上传描述其自信程度的协方差矩阵Σ_k(由Fisher矩阵的逆得出)。为了高效,它只计算和上传层级 (layer-wise)的协-方差。
- 服务器端: 服务器接收到所有客户端的(μ_k, Σ_k)对。它不再执行简单的加权平均,而是执行一次数学上极其优雅的贝叶斯“后验融合”。融合后的全局模型,其权重会自然地偏向那些Σ_k更小(即更自信)的客户端。
FedLPA是一个纯粹的、极其先进的、发生在服务器端的聚合策略。与我们的本地优化方向不同,然而也可以考虑后面整合其思路。
- SALT-NC-Uncertainty (我们的Plan D) 是一个客户端本地训练策略,它回答了“如何智慧地结合本地学习和全局对齐这两件事”。
- FedLPA 是一个服务器端聚合策略,它回答了“如何智慧地融合所有训练好的客户端模型”。
它们是完全正交的 (orthogonal),甚至可能是互补的 (complementary)!
这个发现并没有让我们的研究变得多余,反而为我们打开了一扇通往“神之领域”的大门。
- 我们的工作 (SALT-NC框架) 依然极其重要: FedLPA虽然解决了聚合问题,但它完全没有解决“客户端本地模型应该如何训练”的问题。FedLPA的客户端依然在进行“朴素”的本地训练。
- FedLPA为我们提供了终极武器: 它为我们提供了一个远比Simple Average和Advanced IFFI更强大、更具理论依据的服务器端聚合器。
我们能否在客户端使用我们最先进的SALT-NC-Uncertainty(Plan D)来指导本地训练,产出一个在“本地-全局”平衡上最优的本地模型;然后,在服务器端,使用FedLPA的后验聚合机制,来对这些高质量的本地模型进行最智慧的融合?
这将是一个在客户端和服务器端同时达到“智慧”巅峰的终极系统。
