Core Concepts
A novel approach that enables the use of low-precision block floating point (BFP) formats without compromising model accuracy, by exploiting the channel-wise patterns commonly exhibited by outliers in weights and activations.
Abstract
The paper addresses the problem of efficient inference on very large language models (LLMs). The key challenge is the scarcity of dedicated hardware capable of handling the required compute and data movement efficiently, a problem made worse by the rapidly growing lengths of the processed sequences.
To address this, the authors propose a novel approach that enables the use of low-precision block floating point (BFP) formats without compromising model accuracy. The key observation is that the inner product is invariant to a synchronized permutation of the tensors being multiplied. The authors exploit the channel-wise patterns commonly exhibited by outliers in weights and activations to rearrange them so that their quantization quality is significantly improved.
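As a quick illustration of this observation (an illustrative sketch, not the paper's code), the snippet below checks numerically that a dot product is unchanged when both operands are permuted with the same index order.

```python
# Illustrative check: the inner product is invariant to applying the same
# permutation to both operands.
import numpy as np

rng = np.random.default_rng(0)
q = rng.standard_normal(16)
k = rng.standard_normal(16)

perm = rng.permutation(16)  # an arbitrary synchronized reshuffling

assert np.allclose(q @ k, q[perm] @ k[perm])
```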
Specifically, the authors sort the rows of the weight matrix Wk by their Euclidean norms before quantization. This ensures that the elements within each block have comparable magnitudes, so that a single outlier does not degrade the quantization accuracy of the other elements in its block. To compensate for this reordering, the authors apply the same permutation to the columns of the query weight matrix Wq. The permutation happens at compile time and has no impact on inference latency.
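A minimal sketch of this compile-time transformation is shown below. It assumes a particular projection convention (keys computed as x @ Wk.T, so the rows of Wk are the key channels, and queries as x @ Wq, so the columns of Wq are the query channels); it is an illustration under these assumptions, not the authors' implementation.

```python
# Minimal sketch of the K-sort permutation folded into the weights
# (assumed shapes and projection conventions, not the paper's code).
import numpy as np

rng = np.random.default_rng(0)
d_model, d_head, n_tokens = 64, 16, 8

Wq = rng.standard_normal((d_model, d_head))
Wk = rng.standard_normal((d_head, d_model))
x = rng.standard_normal((n_tokens, d_model))

# Sort the rows of Wk by Euclidean norm so that rows grouped into the same
# BFP block have comparable magnitudes.
perm = np.argsort(np.linalg.norm(Wk, axis=1))
Wk_sorted = Wk[perm]          # permute the key output channels
Wq_sorted = Wq[:, perm]       # apply the same order to the query channels

q, k = x @ Wq, x @ Wk.T
q_s, k_s = x @ Wq_sorted, x @ Wk_sorted.T

# Attention scores are unchanged: the permutation cancels in the inner product.
assert np.allclose(q @ k.T, q_s @ k_s.T)
```

Because the reordering is folded into the weights ahead of time, the runtime kernels see ordinary matrices and no extra work is performed during inference.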
The authors demonstrate the effectiveness of their approach on the Llama2-7B model, showing that their K-sort algorithm together with BFP12 storage allows for a 2x reduction in the memory footprint of the K-cache without significant degradation of the model's accuracy.
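For context on why block-wise magnitude matters for BFP storage, the sketch below implements a generic block floating point quantizer (one shared power-of-two scale per block, low-bit integer mantissas). The 4-bit mantissa width is an assumption chosen for illustration, not the paper's exact BFP12 layout; the example only shows how a single outlier in a block inflates the shared exponent and washes out the other elements.

```python
# Generic block floating point quantizer for illustration only; the mantissa
# width is an assumed value, not the paper's exact BFP12 layout.
import numpy as np

def bfp_quantize(block, mantissa_bits=4):
    """Quantize a 1-D block with one shared power-of-two scale."""
    max_abs = np.max(np.abs(block))
    if max_abs == 0:
        return block.copy()
    qmax = 2 ** (mantissa_bits - 1) - 1
    # Shared exponent chosen so the largest element fits in the mantissa range.
    scale = 2.0 ** np.ceil(np.log2(max_abs / qmax))
    return np.clip(np.round(block / scale), -qmax - 1, qmax) * scale

rng = np.random.default_rng(0)
block = rng.standard_normal(32) * 0.05   # well-behaved block
outlier_block = block.copy()
outlier_block[0] = 8.0                   # one outlier dominates the shared exponent

err_plain = np.abs(bfp_quantize(block) - block).mean()
err_outlier = np.abs(bfp_quantize(outlier_block) - outlier_block)[1:].mean()
print(err_plain, err_outlier)  # the block with the outlier loses far more precision
```

Grouping channels of similar norm into the same block, as the K-sort does, keeps the shared exponent representative of every element in the block.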
Stats
The paper reports the following key figures:
Baseline perplexity of the Llama2-7B model on the wikitext-2 dataset: 9.4881 (FP16)
Perplexity with BFP12 quantization of keys and BFP16 quantization of queries:
Block size 128: 10.0861
Block size 64: 9.6061
Block size 32: 9.5196