A Fast Optimization View: Reformulating Single Layer Attention in LLM Based on Tensor and SVM Trick, and Solving It in Matrix Multiplication Time

Abstract

Large language models (LLMs) have played a pivotal role in revolutionizing various facets of our daily existence. Solving attention regression is a fundamental task in optimizing LLMs. In this work, we focus on giving a provable guarantee for the one-layer attention network objective function $L(X,Y) = \sum_{j_0 = 1}^n \sum_{i_0 = 1}^d \big( \langle \langle \exp( \mathsf{A}_{j_0} x ) , {\bf 1}_n \rangle^{-1} \exp( \mathsf{A}_{j_0} x ), A_{3} Y_{*,i_0} \rangle - b_{j_0,i_0} \big)^2$. Here $\mathsf{A} \in \mathbb{R}^{n^2 \times d^2}$ is the Kronecker product of $A_1 \in \mathbb{R}^{n \times d}$ and $A_2 \in \mathbb{R}^{n \times d}$, $A_3$ is a matrix in $\mathbb{R}^{n \times d}$, and $\mathsf{A}_{j_0} \in \mathbb{R}^{n \times d^2}$ is the $j_0$-th block of $\mathsf{A}$. The matrices $X, Y \in \mathbb{R}^{d \times d}$ are the variables we want to learn; $B \in \mathbb{R}^{n \times d}$, with $b_{j_0,i_0} \in \mathbb{R}$ being the entry in its $j_0$-th row and $i_0$-th column; $Y_{*,i_0} \in \mathbb{R}^d$ is the $i_0$-th column of $Y$; and $x \in \mathbb{R}^{d^2}$ is the vectorization of $X$. In a multi-layer LLM network, the matrix $B \in \mathbb{R}^{n \times d}$ can be viewed as the output of a layer, and $A_1 = A_2 = A_3 \in \mathbb{R}^{n \times d}$ can be viewed as the input of a layer. The matrix version of $x$ can be viewed as $QK^\top$ and $Y$ can be viewed as $V$. We provide an iterative greedy algorithm that trains the loss function $L(X,Y)$ up to accuracy $\epsilon$ in $\widetilde{O}\big( ({\cal T}_{\mathrm{mat}}(n,n,d) + {\cal T}_{\mathrm{mat}}(n,d,d) + d^{2\omega}) \log(1/\epsilon) \big)$ time. Here ${\cal T}_{\mathrm{mat}}(a,b,c)$ denotes the time of multiplying an $a \times b$ matrix by a $b \times c$ matrix, and $\omega \approx 2.37$ denotes the exponent of matrix multiplication.
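To make the objective concrete, the following is a minimal numerical sketch of $L(X,Y)$, evaluated both in the tensor (Kronecker) form above and in the matrix form $\|D^{-1}\exp(A_1 X A_2^\top) A_3 Y - B\|_F^2$ with $D = \mathrm{diag}(\exp(A_1 X A_2^\top) {\bf 1}_n)$, which is the standard attention rewriting under a row-major vectorization $x = \mathrm{vec}(X)$. The dimensions, random data, and vectorization convention are assumptions made for illustration, not part of the paper's algorithm.

```python
import numpy as np

# Illustrative sketch only: evaluates the one-layer attention loss L(X, Y)
# two ways and checks that they agree. Sizes and data are arbitrary.
rng = np.random.default_rng(0)
n, d = 6, 3
A1 = rng.standard_normal((n, d))
A2 = rng.standard_normal((n, d))
A3 = rng.standard_normal((n, d))
B  = rng.standard_normal((n, d))
X  = rng.standard_normal((d, d))   # plays the role of Q K^T
Y  = rng.standard_normal((d, d))   # plays the role of V

# Matrix form: D^{-1} exp(A1 X A2^T) A3 Y, with D = diag(exp(A1 X A2^T) 1_n).
S = np.exp(A1 @ X @ A2.T)                    # n x n unnormalized scores
attn = S / S.sum(axis=1, keepdims=True)      # softmax-style row normalization
loss_matrix = np.sum((attn @ (A3 @ Y) - B) ** 2)

# Tensor form: A = A1 kron A2, and A_{j0} is the j0-th (n x d^2) block of A.
A = np.kron(A1, A2)                          # n^2 x d^2
x = X.flatten()                              # vec(X), row-major order
loss_tensor = 0.0
for j0 in range(n):
    u = np.exp(A[j0 * n:(j0 + 1) * n, :] @ x)    # exp(A_{j0} x) in R^n
    row = u / u.sum()                            # <exp(A_{j0} x), 1_n>^{-1} exp(A_{j0} x)
    for i0 in range(d):
        loss_tensor += (row @ (A3 @ Y[:, i0]) - B[j0, i0]) ** 2

assert np.isclose(loss_matrix, loss_tensor)      # the two formulations agree
```

The check passes because, with row-major vectorization, the $j_0$-th block of the Kronecker product satisfies $\mathsf{A}_{j_0} x = (A_1 X A_2^\top)_{j_0,*}$, so each normalized vector in the sum is exactly one row of the attention matrix.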
