Deep neural networks have shown the ability to extract universal feature
representations from data such as images and text, representations that are
useful for a variety of learning tasks. However, the fruits of
representation learning have
yet to be fully realized in federated settings. Although data in federated
settings is often non-i.i.d. across clients, the success of centralized deep
learning suggests that data often shares a global feature representation, while
the statistical heterogeneity across clients or tasks is concentrated in the
labels. Based on this intuition, we propose a novel federated learning
framework and algorithm for learning a shared data representation across
clients and unique local heads for each client. Our algorithm harnesses the
distributed computational power across clients to perform many local updates
with respect to the low-dimensional local parameters for every update of the
representation. We prove that this method obtains linear convergence to the
ground-truth representation with near-optimal sample complexity in a linear
setting, demonstrating that it can efficiently reduce the problem dimension for
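To make the linear-setting claim concrete, one standard formalization (the notation below is an illustrative assumption, not a quotation of the theorem statement) has each client $i$ observe samples from a linear model with a shared ground-truth representation,
\[
  y_{i,j} \;=\; \langle x_{i,j},\, B^{*} w_{i}^{*} \rangle + \varepsilon_{i,j},
  \qquad B^{*} \in \mathbb{R}^{d \times k},\quad w_{i}^{*} \in \mathbb{R}^{k},\quad k \ll d,
\]
and linear convergence to the ground-truth representation means the learned column space contracts geometrically,
\[
  \operatorname{dist}\!\left(B^{t}, B^{*}\right) \;\le\; \rho^{\,t}\, \operatorname{dist}\!\left(B^{0}, B^{*}\right)
\]
for some contraction factor $\rho \in (0,1)$ and a subspace distance $\operatorname{dist}(\cdot,\cdot)$ such as the principal-angle distance.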
each client. Further, we provide extensive experimental results demonstrating
the improvement of our method over alternative personalized federated learning
approaches in heterogeneous settings.
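The following is a minimal sketch of one communication round of the alternating scheme in the linear setting described above. The function name, learning rates, step counts, and client data layout are illustrative assumptions, not the paper's implementation: each client takes many cheap gradient steps on its low-dimensional head with the representation frozen, then a single step on the representation, and the server averages the resulting representations.

```python
import numpy as np

def alternating_round(B, clients, head_steps=10, lr_head=0.1, lr_rep=0.1):
    """One round of the alternating update in the linear setting (sketch only;
    all names and hyperparameters here are illustrative assumptions).

    B:       (d, k) current shared representation.
    clients: list of dicts with keys "X" ((n_i, d) features),
             "y" ((n_i,) labels), and "w" ((k,) local head, kept on-device).
    """
    B_updates = []
    for c in clients:
        X, y, w = c["X"], c["y"], c["w"]
        n = len(y)
        # Many local updates on the low-dimensional head, B frozen.
        for _ in range(head_steps):
            resid = X @ (B @ w) - y                       # (n,)
            w = w - lr_head * (B.T @ (X.T @ resid)) / n   # grad of 0.5/n * ||XBw - y||^2
        c["w"] = w
        # One local gradient step on the representation, head frozen.
        resid = X @ (B @ w) - y
        grad_B = np.outer(X.T @ resid, w) / n             # (d, k)
        B_updates.append(B - lr_rep * grad_B)
    # The server averages the clients' proposed representations.
    return np.mean(B_updates, axis=0)

# Hypothetical usage: 5 clients sharing a ground-truth subspace B_star,
# each with its own head w_star (d = 20, k = 3, 50 samples per client).
rng = np.random.default_rng(0)
d, k = 20, 3
B_star = np.linalg.qr(rng.normal(size=(d, k)))[0]
clients = []
for _ in range(5):
    X = rng.normal(size=(50, d))
    w_star = rng.normal(size=k)
    y = X @ (B_star @ w_star) + 0.01 * rng.normal(size=50)
    clients.append({"X": X, "y": y, "w": np.zeros(k)})
B = np.linalg.qr(rng.normal(size=(d, k)))[0]              # random orthonormal init
for _ in range(100):
    B = alternating_round(B, clients)
```

Note the design point the sketch illustrates: only the $d \times k$ representation is communicated and averaged, while each head stays local, so the per-round client computation is concentrated on the $k$-dimensional head problem.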