核心概念
大規模言語モデルのファインチューニングにおける課題であるメモリ使用量の増大に対し、本稿では、スパースチューニングと低ランク行列を用いた表現学習を組み合わせたSNELL (Sparse tuning with kerNElized LoRA) を提案する。SNELLは、従来のスパースチューニングの手法と比較して、メモリ使用量を大幅に削減しながらも、同等以上の性能を実現する。
要約
メモリ使用量を抑えたスパースチューニング:SNELLの紹介
大規模な事前学習済みモデルを、限られた学習データで特定のタスクに適応させる際に、ファインチューニングが広く用いられています。しかし、モデルのパラメータ全体をファインチューニングすることは、メモリ使用量が多く、過剰適合のリスクも高いため、大規模モデルでは非効率です。
これを解決するために、パラメータ効率の良いファインチューニング(PEFT)が注目されています。PEFTは、モデルパラメータの一部のみを調整することで、メモリ使用量と過剰適合のリスクを抑えつつ、高い性能を実現します。PEFTは、追加ベースの手法と再パラメータ化ベースの手法に大別されます。
追加ベースの手法は、事前学習済みのバックボーンにパラメータを追加して学習します。AdapterやPrompt-tuningなどが代表的な手法です。
再パラメータ化ベースの手法は、事前学習済みのバックボーンのパラメータを直接調整します。BitFit、Partial-k、LoRAなどが代表的な手法です。
スパースチューニングの課題
近年、PEFTにおいて、タスクに関連性の高い重みのみを調整するスパースチューニングが注目されています。スパースチューニングは、行列内の個々の重みに着目することで、より正確な調整を可能にし、高い性能と過剰適合リスクの軽減を実現します。
しかし、スパースチューニングは、高いメモリ使用量という課題を抱えています。スパースチューニングでは、事前学習済み重み行列の一部のみを更新しますが、実際には行列全体を学習可能なパラメータとしてオプティマイザに格納し、対応する勾配を計算する必要があります。さらに、調整可能な重みのインデックスを格納する必要があるため、メモリ使用量がさらに増加します。
本稿で提案するSNELLは、スパースチューニングとLoRAを組み合わせることで、低いメモリ使用量と高い性能を両立させる手法です。
メモリ使用量の削減
SNELLは、スパース化のために調整可能な行列を低ランクの学習可能な行列に分解することで、オプティマイザに格納するパラメータを減らし、メモリ使用量を削減します。
競合ベースのスパース化メカニズム
従来のスパース化手法では、調整可能な重みのインデックスを格納する必要がありましたが、SNELLは、競合ベースのパラメータスパース化メカニズムを採用することで、このインデックスの格納を不要にします。
このメカニズムは、神経科学におけるニューロン競合現象に着想を得ており、重みをその絶対値に基づいて競合させます。タスクに関連性の高い重みは、より大きな絶対値を持つように促され、ファインチューニングプロセス中に残ります。スパース比をハイパーパラメータとして設定し、調整可能な重みをその絶対値に基づいてエンドツーエンドで決定することで、調整可能な重みのインデックスを格納する必要がなくなります。
カーネル化LoRA
SNELLは、低ランク行列を用いつつ、高い性能を実現するために、カーネル化LoRAを採用しています。
LoRAでは、低ランク行列AとBをマージして、事前学習済み重み行列W0に適用します。
W = W0 + ∆W = W0 + BA⊤
しかし、∆Wはランクrの低ランク行列であるため、スパースチューニングの重み最適化の範囲が狭まり、性能が制限されます。
そこで、SNELLでは、低ランク行列を用いて高ランク行列を構築するために、DyN [45] の考え方を応用し、距離関数を一般的なカーネル関数に拡張して、カーネルの観点からLoRAを検討します。
カーネル関数κ(x, x')は、陰的特徴マップϕを用いて、内積ϕ(x)⊤ϕ(x')として表現できます。LoRAのマージプロセスは、学習可能なパラメータAとBの行に線形カーネル関数κl(·, ·)を適用していると見なすことができます。
∆Wij = κl(Ai,·, Bj,·) = ϕl(Bj,·)ϕl(Ai,·)⊤= Bj,·A⊤
i,·,
κl(·, ·)をより複雑な非線形カーネル関数に置き換えることで、高次元空間Rdにおける関係を近似し、rよりも高いランクの行列を得ることができます。SNELLのマージされた適応行列は、以下のように表すことができます。
∆W = (κ(Ai,·, Bj,·))m×n = [ϕ(B1,·)⊤, ..., ϕ(Bn,·)⊤]⊤[ϕ(A1,·)⊤, ..., ϕ(An,·)⊤] = BϕA⊤
ϕ .
実際には、Aϕ∈Rn×dとBϕ∈Rm×dを明示的に計算する必要はありません。∆Wは、カーネル関数κを用いて、AとBから直接導出できます。
SNELLは、カーネルの観点からLoRAを拡張することで、低ランクの学習可能な行列に基づいて高ランクの適応行列を構築し、低いメモリ使用量で強力なスパースチューニングを実現します。