Data heterogeneity across clients is a key challenge in federated learning.
Prior works address this by either aligning client and server models or using
control variates to correct client model drift. Although these methods achieve
fast convergence for convex or simple non-convex problems, their performance on
over-parameterized models such as deep neural networks lags behind. In this
paper, we first revisit the widely used FedAvg algorithm on deep neural
networks to understand how data heterogeneity influences the gradient updates
across the network layers. We observe that while FedAvg learns the feature
extraction layers efficiently, the substantial diversity of the final
classification layers across clients impedes performance. Motivated
by this, we propose to correct model drift by applying variance reduction only
to the final layers. We demonstrate that this significantly outperforms existing
benchmarks at a similar or lower communication cost. We furthermore prove the
convergence rate of our algorithm.
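
To make the key idea concrete, the sketch below illustrates one way a SCAFFOLD-style control-variate correction could be restricted to the final classification layer while the feature-extraction layers follow plain local SGD. This is a minimal illustrative sketch, not the paper's implementation: the model definition, the helper name partial_vr_client_update, and the hyperparameters are assumptions, and server aggregation as well as control-variate bookkeeping are omitted.

```python
# Minimal sketch (assumptions, not the authors' code): variance reduction is
# applied only to the gradients of the final classifier layer.
import copy
import torch
import torch.nn as nn


class ClientNet(nn.Module):
    """Toy model: a feature extractor followed by a final classification layer."""

    def __init__(self, in_dim=32, hidden=64, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU())
        self.classifier = nn.Linear(hidden, num_classes)  # only this part is corrected

    def forward(self, x):
        return self.classifier(self.features(x))


def partial_vr_client_update(model, data_loader, c_global, c_local, lr=0.1, epochs=1):
    """Run local SGD; add the control-variate correction (c_global - c_local)
    only to the gradients of the final classifier parameters."""
    model = copy.deepcopy(model)
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(epochs):
        for x, y in data_loader:
            model.zero_grad()
            loss_fn(model(x), y).backward()
            with torch.no_grad():
                # Plain SGD step on the feature-extraction layers.
                for p in model.features.parameters():
                    p -= lr * p.grad
                # Variance-reduced step on the classifier layers only.
                for p, cg, cl in zip(model.classifier.parameters(), c_global, c_local):
                    p -= lr * (p.grad + cg - cl)
    return model
```

Because the correction is only kept for the classifier parameters, the per-round communication of control variates is a small fraction of the full model size, which is consistent with the similar-or-lower communication cost mentioned above.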