大模型1.58位极端量化

随着大型语言模型 (LLM) 的规模和复杂性不断增长,寻找降低其计算和能源成本的方法已成为一项关键挑战。

一种流行的解决方案是量化,其中参数的精度从标准的 16 位浮点 (FP16) 或 32 位浮点 (FP32) 降低到 8 位或 4 位等低位格式。虽然这种方法可以显着减少内存使用量并加快计算速度,但往往以牺牲准确性为代价。过度降低精度会导致模型丢失关键信息,从而导致性能下降。

BitNet 是一种特殊的 Transformer 架构,它仅用三个值表示每个参数:(-1, 0, 1),每个参数仅提供 1.58 ( log2(3) ) 位的极端量化。但是,它需要从头开始训练模型。

虽然结果令人印象深刻,但并不是每个人都有预算来预先训练 LLM。为了克服这一限制,我们探索了一些技巧,可以将现有模型微调到 1.58 位!继续阅读以了解如何操作!

1、BitNet概述

BitNet 是微软研究院推出的一种架构,它使用极端量化,每个参数仅用三个值表示:-1、0 和 1。这使得模型每个参数仅使用 1.58 位,从而大大降低了计算和内存要求。

该架构在执行矩阵乘法时使用 INT8 加法计算,而 LLaMA LLM 的加法和乘法运算则使用 FP16。

BitNet b1.58 的新计算范式。来源:BitNet 论文

这在理论上降低了能耗,与 Llama 基线相比,BitNet b1.58 节省了 71.4 倍的矩阵乘法算术运算能量。

BitNet b1.58 与 LLama 的能耗对比。来源:BitNet 论文

我们已成功使用 BitNet 架构对 Llama3 8B 模型进行微调,在下游任务中取得了出色的表现。我们开发的 8B 模型在 HF1BitLLM 组织下发布。其中两个模型在 10B 令牌上进行了微调,训练设置不同,而第三个模型在 100B 令牌上进行了微调。值得注意的是,我们的模型在 MMLU 基准测试中超越了 Llama 1 7B 模型。

1.1 如何与 Transformers 一起使用

为了将 BitNet 架构集成到 Transformers 中,我们引入了一种称为“bitnet”(PR)的新量化方法。该方法涉及用与 BitNet 架构兼容的专用 BitLinear 层替换标准 Linear 层,并对激活、权重解包和矩阵乘法进行适当的动态量化。

在 Transformers 中加载和测试模型非常简单,API 没有任何变化:

model = AutoModelForCausalLM.from_pretrained(
    "HF1BitLLM/Llama3-8B-1.58-100B-tokens",
    device_map="cuda",
    torch_dtype=torch.bfloat16
)    
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

input_text = "Daniel went back to the the the garden. Mary travelled to the kitchen. Sandra journeyed to the kitchen. Sandra went to the hallway. John went to the bedroom. Mary went back to the garden. Where is Mary?\nAnswer:"

input_ids = tokenizer.encode(input_text, return_tensors="pt").cuda()
output = model.generate(input_ids, max_new_tokens=10)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print(generated_text)

使用此代码,一切都可以在后台无缝管理,因此无需担心额外的复杂性,只需安装最新版本的 transformers 即可。

要快速测试模型,请查看此笔记本

2、更深入地了解 BitNet?

BitNet 用称为 BitLinear 的专用层替换了多头注意力和前馈网络中的传统线性层,这些层使用三元精度(在初始版本中甚至是二进制)。我们在此项目中使用的 BitLinear 层使用三元精度(值为 -1、0 和 1)量化权重,并将激活量化为 8 位精度。我们在训练中使用的 BitLinear 实现与推理不同,我们将在下一节中看到。

三元精度训练的主要障碍是权重值是离散化的(通过 round() 函数),因此不可微分。BitLinear 用一个很好的技巧解决了这个问题:STE(直通估计器)。 STE 通过将梯度近似为 1(将 round() 视为等同于恒等函数)允许梯度流过不可微的舍入运算。另一种看待它的方式是,STE 不会在舍入步骤停止梯度,而是让梯度通过,就好像舍入从未发生过一样,从而能够使用基于梯度的标准优化技术更新权重。

具有 BitLinear 层的 BitNet 架构。来源:BitNet 论文 


2.1 训练

我们以全精度进行训练,但使用对称张量量化将权重量化为三元值。首先,我们计算权重矩阵绝对值的平均值并将其用作比例。然后,我们将权重除以比例,对值进行四舍五入,将它们限制在 -1 和 1 之间,最后重新调整它们以继续以全精度进行训练。

然后使用 absmax per token 量化将激活量化为指定的位宽(在我们的例子中为 8 位)(有关量化方法的全面介绍,请查看此文章)。这涉及将激活缩放到 8 位位宽的范围 [−128, 127]。量化公式为:

我们在量化激活之前应用层归一化(LN)来维持输出的方差:

其中 ϵ 是一个较小的数字,以防止溢出。

如前所述, round() 函数不可微。我们使用 detach() 作为技巧,在后向传递中实现可微分的直通估计器:

# Adapted from https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf
import torch
import torch.nn as nn 
import torch.nn.functional as F

def activation_quant(x):
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    y = (x * scale).round().clamp_(-128, 127) / scale
    return y
 
def weight_quant(w):
    scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
    u = (w * scale).round().clamp_(-1, 1) / scale
    return u

class BitLinear(nn.Linear):
    """
    Only for training
    """
    def forward(self, x):
        w = self.weight
        x_norm = LN(x)
        
        # A trick for implementing Straight−Through−Estimator (STE) using detach()
        x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
        w_quant = w + (weight_quant(w) - w).detach()
        
        # Perform quantized linear transformation
        y = F.linear(x_quant, w_quant)
        return y

2.2 推理

在推理过程中,我们只需将权重量化为三元值,而无需重新缩放。我们使用 8 位精度对激活应用相同的方法,然后使用高效内核执行矩阵乘法,然后除以权重和激活尺度。这应该会显著提高推理速度,特别是在优化硬件的情况下。您可以看到,重新缩放过程在训练过程中有所不同,因为矩阵乘法保留在 fp16/bf16/fp32 中以进行适当的训练。

# Adapted from https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf
import torch
import torch.nn as nn 
import torch.nn.functional as F

def activation_quant_inference(x):
    x = LN(x)
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    y = (x * scale).round().clamp_(-128, 127)
    return y, scale
 
class BitLinear(nn.Linear):
    """
    Only for training
    """
    def forward(self, x):
        w = self.weight # weights here are already quantized to (-1, 0, 1)    
        w_scale = self.w_scale  
        x_quant, x_scale = activation_quant_inference(x)
        y = efficient_kernel(x_quant, w) / w_scale / x_scale
        return y

3、1.58bit 的预训练结果

在尝试微调之前,我们首先尝试通过预训练重现 BitNet 论文的结果。我们从一个小数据集、tinystories 和一个 Llama3 8B 模型开始。我们确认添加规范化函数(如论文中所述)可以提高性能。例如,经过 2000 步训练后,我们在验证集上的困惑度在未进行规范化的情况下等于 6.3,在进行规范化的情况下等于 5.9。在这两种情况下,训练都很稳定。

没有(蓝色)和有(橙色)层归一化的预训练图

虽然这种方法对于预训练来说看起来非常有趣,但只有少数机构能够负担得起在必要规模上进行这种训练。然而,已经存在大量强大的预训练模型,如果它们可以在预训练后转换为 1.58 位,那将非常有用。其他小组报告说,微调结果不如预训练结果那么好,所以我们开始调查,看看我们能否使 1.58 微调发挥作用。

4、1.58bit 的微调

当我们开始从预训练的 Llama3 8B 权重进行微调时,模型的表现略好,但没有我们预期的那么好。

注意:我们所有的实验都是使用 Nanotron 进行的。如果你有兴趣尝试 1.58bit 预训练或微调,可以查看此 PR
微调图与预训练图的比较

为了了解原因,我们尝试检查随机初始化模型和预训练模型的权重分布,以识别潜在问题。

左:随机权重分布。右:预训练llama3权重分布

两个分布的尺度值分别为:

左:随机权重尺度分布。右:预训练llama3权重尺度分布

显然,预训练模型从更多信息(尺度)开始,而随机初始化模型从几乎没有信息开始,并随着时间的推移不断增加信息。我们的结论是,从随机权重开始会为模型提供最少的初始信息,从而实现逐步学习过程,而在微调过程中,引入 BitLinear 层会使模型不堪重负,从而丢失所有先前信息。

为了改善微调结果,我们尝试了不同的技术。例如,我们尝试了每行和每列量化,而不是使用每个张量量化,以从 Llama 3 权重中保留更多信息。我们还尝试改变尺度的计算方式:我们不是仅仅将权重的平均绝对值作为尺度,而是将异常值的平均绝对值作为尺度(异常值是超过 k*mean_absolute_value 的值,其中 k 是我们在实验中尝试改变的常数),但我们没有注意到大的改进。

def scale_outliers(tensor, threshold_factor=1):
    mean_absolute_value = torch.mean(torch.abs(tensor))
    threshold = threshold_factor * mean_absolute_value
    outliers = tensor[torch.abs(tensor) > threshold]
    mean_outlier_value = torch.mean(torch.abs(outliers))
    return mean_outlier_value

def weight_quant_scaling(w):
    scale = 1.0 / scale_outliers(w).clamp_(min=1e-5)
    quantized_weights = (w * scale).round().clamp_(-1, 1) / scale
    return quantized_weights

我们观察到,随机权重和 Llama 3 权重都导致损失,损失值大约从 13 开始。这表明,在引入量化时,Llama 3 模型会丢失所有先前的信息。为了进一步研究模型在此过程中丢失了多少信息,我们尝试了每组量化。

作为健全性检查,我们首先将组大小设置为 1,这实际上意味着没有量化。在这种情况下,损失从 1.45 开始,与我们在正常微调期间看到的相同。但是,当我们将组大小增加到 2 时,损失跳升到 11 左右。这表明,即使最小组大小为 2,模型仍然会丢失几乎所有的信息。

为了解决这个问题,我们考虑了逐步引入量化的可能性,而不是突然将其应用于每个张量的权重和激活。为了实现这一点,我们实现了一个 lambda 值来控制这个过程:

lambda_ = ?
x_quant = x + lambda_ * (activation_quant(x) - x).detach()
w_quant = w + lambda_ * (weight_quant(w) - w).detach()

当 lambda 设置为 0 时,基本上不会发生量化,而当 lambda=1 时,则会应用完全量化。

我们最初测试了一些离散的 lambda 值,例如 0.25、0.5、0.75 和 1。然而,这种方法并没有带来任何显著的结果改善,主要是因为 lambda=0.25 已经足够高,损失一开始就很高。

使用 lambda = 0.25->0.5->0.75->1 进行微调图

因此,我们决定尝试使用根据训练步骤动态调整的 lambda 值。

lambda_ = training_step / total_training_steps

使用这个动态 lambda 值可以实现更好的损失收敛,但当 lambda 设置为 1 时,推理过程中的困惑度 (ppl) 结果仍然远远不能令人满意。我们意识到这可能是因为模型在 lambda=1 的情况下训练的时间不够长。为了解决这个问题,我们调整了 lambda 值以改进训练过程。

lambda_ = min(2 * training_step / total_training_steps, 1)

通过此配置,经过 2000 步后我们得到:

使用 lambda = min(2*training_step/total_training_steps, 1) 的微调图

我们的微调方法总体上显示出更好的收敛性。您可以观察到损失曲线在 1,000 步左右略有增加,这对应于我们开始接近 lambda=1 或完全量化时的情况。然而,在此之后,损失立即开始再次收敛,导致困惑度提高到大约 4。

尽管取得了这一进展,但当我们在 WikiText 数据集(而不是我们用于微调的 tinystories 数据集)上测试量化模型时,它显示出非常高的困惑度。这表明在特定数据集上以低位模式微调模型会导致其丢失大部分一般知识。这个问题可能是因为我们使用三元权重所追求的最小表示可能因数据集而异。为了解决这个问题,我们扩展了我们的训练过程以包括更大的 FineWeb-edu 数据集。我们保持了 lambda 值:

lambda_ = min(training_step/1000, 1)

我们之所以选择这个 lambda 值,是因为它似乎是预热模型的一个很好的起点。然后,我们在 FineWeb-edu 数据集上使用 1e-4 的学习率对模型进行了 5,000 步训练。训练涉及 200 万的批处理大小 (BS),总计 100 亿个标记。

找到合适的学习率和合适的衰减很有挑战性;这似乎是模型性能的关键因素。

在 Fineweb-edu 上使用预热量化进行微调的图

在 Fineweb-Edu 上进行微调后,WikiText 数据集上的困惑度达到了 12.2,考虑到我们只使用了 100 亿个标记,这相当令人印象深刻。考虑到数据量有限,其他评估指标也表现出色(见结果)。

我们还试图平滑 lambda 接近 1 时的急剧增长。为此,我们考虑使用 lambda 调度程序,这些调度程序最初呈指数增长,然后在接近 1 时趋于平稳。

def scheduler(step, total_steps, k):
    normalized_step = step / total_steps
    return 1 - (1 - normalized_step)**k

对于不同的 k 值,总预热步数为 1,我们有如下图所示的图表:

不同 k 值的指数调度程序

我们使用表现最佳的 1e-4 学习率进行了 4 次实验,测试 k 值在 [4、6、8、10] 中。

使用指数调度程序微调图

平滑效果很好,因为没有像线性调度程序那样出现峰值。但是,困惑度不是很好,保持在~15 左右,下游任务的性能也没有更好。

我们还注意到开始时的峰值,模型很难从中恢复过来。当 lambda = 0 时,基本上没有量化,因此损失从低处开始,大约为 ~2。但在第一步之后,出现了一个峰值,类似于线性调度程序发生的情况(如上图蓝色所示)。因此,我们尝试了一个不同的调度程序 - 一个 S 形调度程序 - 它开始时很慢,急剧上升到 1,然后在接近 1 时趋于平稳。

def sigmoid_scheduler(step, total_steps, k):
    # Sigmoid-like curve: slow start, fast middle, slow end
    normalized_step = step / total_steps
    return 1 / (1 + np.exp(-k * (normalized_step - 0.5)))

对于不同的 k 值,我们有以下曲线:

不同 k 值的 Sigmoid 调度器

这次我们进行了 5 次实验,k 取值范围为 [15, 20, 25, 40, 100]:

使用 S 型调度程序微调图

lambda 的急剧增加导致第 500 步左右不稳定,并且没有解决第一个发散问题。但是,对于 k=100,我们观察到下游任务有所改善(参见结果表),尽管困惑度仍然保持在 ~13.5 左右。尽管如此,它并没有显示出与线性调度程序相比明显的性能提升。

此外,我们使用随机权重和各种学习率从头开始尝试训练模型。这使我们能够将微调方法与传统预训练方法的有效性进行比较。

具有不同学习率的不同预训练图

从随机权重训练的模型中,没有一个表现优于我们的微调模型。我们用这些模型实现的最佳困惑度为 26,与我们的微调方法的结果相比,这还不够。

4.1 扩展到 100B 令牌!

我们将实验扩展到 1000 亿个令牌,看看我们能否达到 Llama 3 8B 的性能。我们进行了更长的训练运行,从使用线性调度程序的较短运行中表现最佳的检查点开始,并继续微调 45,000 步。我们尝试了不同的学习率,虽然该模型在某些指标上的表现与 Llama 3 模型接近,但平均而言,它仍然落后。

以下是我们在训练期间在各个检查点评估的一些指标示例:

训练期间对不同 lrs 的指标评估

平均分数如下:

不同 lrs 训练期间的平均评估

4.2 较小模型上的实验

在我们对 SmolLM 等较小模型的初步实验中,我们观察到预热量化技术并没有像在较大模型中那样产生那么多的改进。这表明预热量化的有效性可能与模型大小和复杂性更密切相关。

例如,这是 SmolLM 135M 模型的损失曲线,从一开始将预热量化与完全量化进行比较。有趣的是,曲线非常接近,并且产生的困惑度并没有显着差异。

Smoll LLm 微调实验,有无预热量化

4.3 结果与比较

与基线方法相比,BitNet 能够有效提供强大的性能,尤其是在较低位级别。根据论文,BitNet 的得分与 8 位模型相当,但推理成本明显较低。在 4 位模型的情况下,仅量化权重的方法优于量化权重和激活的方法,因为激活更难量化。然而,使用 1.58 位权重的 BitNet 超越了仅权重和权重和激活量化方法。

下表显示了 Llama3 8B 经过 10B 微调过程后各种指标的结果。这些结果与其他模型架构的结果进行了比较,以提供全面的性能概述(所有评估均使用 LightevalNanotron 格式模型上进行)

与 Llama 模型的指标比较:Linear 表示线性 lambda 调度程序,Sigmoid 表示 Sigmoid lambda 调度程序(在我们的例子中,k = 100)

在使用三元权重对 100 亿个 token 进行微调后,该模型表现出令人印象深刻的性能,尤其是与经过更广泛训练的其他模型相比。例如,它优于 Bitnet 7B 模型,后者是在 1000 亿个 token 的明显更大的数据集上训练的。此外,它的表现优于 FBI LLM(完全二值化 LLM),后者是在更庞大的 1.26 万亿个 token 上提炼的模型。尽管微调过程的规模相对较小,但这凸显了模型的效率和有效性。

对于 100B token 实验,我们拥有的最佳性能检查点如下:

与使用 100B 个 token 训练的 Llama 模型的指标比较

要复制这些结果,您可以查看此 PR,将模型转换为 nanotron 格式,解压权重(检查函数 unpack_weights),并使用 lighteval

请注意,即使模型是从 Instruct 调整模型微调而来的,它们仍然需要使用 Instruct 数据集进行微调。这些可以被视为基础模型。

5、自定义内核和基准测试

为了从 BitNet 低精度权重中获益,我们将它们打包成一个 int8 张量(这使得参数数量从 8B 增加到 2.8B!)。在推理过程中,必须在执行矩阵乘法之前解包这些权重。我们在 Cuda 和 Triton 中实现了自定义内核,以处理矩阵乘法过程中的即时解包。对于矩阵乘法本身,我们采用了缓存平铺矩阵乘法技术。为了完全掌握这种方法,让我们首先回顾一些 Cuda 编程基础知识。

5.1 基本 GPU 概念:线程、块和共享内存

在深入研究缓存平铺矩阵乘法之前,了解一些基本的 GPU 概念很重要:

  • 线程和块:GPU 同时执行数千个线程。这些线程被分组为块,每个块独立运行。网格由这些块组成,它代表整个问题空间。例如,在矩阵乘法中,每个线程可能负责计算输出矩阵的单个元素。
  • 共享内存:每个块可以访问有限数量的共享内存,这比全局内存(GPU 上的主内存)快得多。但是,共享内存的大小有限,并且由块内的所有线程共享。有效使用共享内存是提高 GPU 程序性能的关键。

5.2 矩阵乘法中的挑战

在 GPU 上实现矩阵乘法的简单方法可能涉及每个线程通过直接从全局内存读取必要元素来计算结果矩阵的单个元素。但是,这种方法可能效率低下,原因如下:

  • 内存带宽:与 GPU 核心执行计算的速度相比,访问全局内存相对较慢。如果每个线程直接从全局内存读取矩阵元素,则内存访问时间可能成为瓶颈。
  • 冗余数据访问:在矩阵乘法中,输入矩阵的许多元素被多次使用。如果每个线程都独立地从全局内存中获取所需数据,则相同的数据可能会多次加载到 GPU 中,从而导致效率低下。例如,如果每个线程都用于计算输出矩阵中的单个元素,则负责计算位置 (i, j) 处元素的线程将需要从全局内存中加载矩阵 A 的第 i 行和矩阵 B 的第 j 列。但是,其他线程(例如计算位置 (i+1, j) 处元素的线程)无法重用此数据,并且必须再次从全局内存中重新加载相同的第 j 列。

5.3 平铺的理念

平铺是一种用于解决这些挑战的技术,它主要用于 FlashAttention 中以提高内核的效率。基本理念是将矩阵划分为较小的子矩阵(称为图块),这些子矩阵可以放入 GPU 的共享内存中。计算不是一次性计算整个输出矩阵,而是分解为较小的部分,然后逐个图块进行处理。

在矩阵乘法的上下文中,这意味着将矩阵 A 和 B 分成块(图块),将这些图块加载到共享内存中,然后对这些较小的块执行乘法。这种方法允许线程重用存储在快速共享内存中的数据,从而减少重复访问全局内存的需要。

工作原理如下:

  • 将图块加载到共享内存中:每个线程块协同将矩阵 A 的一个图块和矩阵 B 的一个对应图块从全局内存加载到共享内存中。此操作对每个图块执行一次,然后块中的线程多次重用该图块。
  • 计算部分乘积:将图块加载到共享内存后,每个线程都会计算部分乘积。由于块中的所有线程都在共享内存中处理相同的图块,因此它们可以有效地重用数据而无需额外的全局内存访问。
  • 累积结果:计算一个图块的部分乘积后,线程将矩阵 A 和 B 的下一个图块加载到共享内存中并重复该过程。结果累积在寄存器(或本地内存)中,一旦所有图块都已处理,输出矩阵元素的最终值就会写回到全局内存中。
平铺矩阵乘法图示。来源:github


在实现缓存平铺矩阵乘法时,需要考虑几个因素:

  • 平铺大小:应选择平铺的大小以平衡可放入共享内存的数据量和全局内存访问次数之间的权衡。
  • 内存合并:全局内存访问被合并,这意味着相邻线程访问相邻的内存位置。
  • 占用率:应选择每个块的线程数和网格中的块数以确保高占用率,这意味着在 GPU 上拥有尽可能多的活动 Warp(一个 Warp 是一组 32 个线程)以隐藏内存延迟。

5.4 Triton 内核

这是我们进行基准测试的 triton 内核:

@triton.autotune(
    configs=get_cuda_autotune_config(),
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
        a_ptr, b_ptr, c_ptr,
        M, N, K,
        stride_am, stride_ak,
        stride_bk, stride_bn, 
        stride_cm, stride_cn,
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,  
        GROUP_SIZE_M: tl.constexpr,
):

    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)

    for i in range(4) : 
        b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
        for j in range(0, tl.cdiv(K // 4, BLOCK_SIZE_K) ):
            k = i * tl.cdiv(K // 4, BLOCK_SIZE_K) + j 

            # BLOCK_SIZE_K must be a divisor of K / 4 
            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0)
            b_uint8 = tl.load(b_ptrs, mask=offs_k[:, None] < K // 4 - j * BLOCK_SIZE_K, other=0)
            mask = 3<<(2*i)
            b = ((b_uint8 & mask) >> (2*i))

            # We accumulate the tiles along the K dimension.
            tensor_full = tl.full((1,), 1, dtype=tl.int8)

            accumulator += tl.dot(a, (b.to(tl.int8) - tensor_full), out_dtype=tl.int32)

            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs += BLOCK_SIZE_K * stride_bk

    c = accumulator

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


def matmul(a, b):
    assert a.shape[1] == b.shape[0] * 4, "Incompatible dimensions, the weight matrix need to be packed"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    M, K = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
    )
    return c

5.5 代码分解

1、确定图块位置

内核首先确定每个线程块负责输出矩阵的哪个图块(块):

  • pid 是每个线程块的唯一标识符,使用 tl.program_id(axis=0) 获取。
  • 网格被分成线程块组(GROUP_SIZE_M)。每个组处理输出矩阵的一部分。
  • pid_m 和 pid_n 分别是 M 和 N 维度中图块的坐标。
  • 计算偏移量(offs_am、offs_bn、offs_k)以确定块中的每个线程将处理矩阵 A 和 B 的哪些元素

2、加载和计算图块

内核使用循环以 BLOCK_SIZE_K 的块形式迭代 K 维度。对于每个块:

  • 加载图块:从全局内存加载矩阵 A 和 B 中的图块。
  • 解包矩阵 B:内核假设矩阵 B 用 int8 值打包,这意味着每个元素实际上代表打包成一个字节的四个较小的值。解包发生在循环内:
  • b_uint8 从全局内存加载为打包的 int8。
  • 每个打包值都被解包以获取用于计算的实际权重值。
  • 点积:内核计算从 A 和 B 加载的图块的点积,并将结果累积在累加器中。累加器存储输出矩阵 C 图块的部分结果。

3、存储结果

在处理完 K 维度上的所有图块后,存储在累加器中的最终结果将转换为 float16 并写回全局内存中矩阵 C 的相应图块。写入过程使用掩码尊重内存边界,以确保只写入有效元素。

有关代码的更详细说明,请查看此 PR

5.6 基准测试

我们根据使用 @torch.compile 解压权重的方法对我们的内核进行了基准测试,然后以 BF16 精度执行 matmul,发现两种方法都实现了大致相同的性能。为了确保准确的基准测试,我们在 2000 次迭代中执行了 matmul 操作,并平均了最后 1000 次迭代所花费的时间,以消除与初始加载或编译相关的任何低效率。下面是显示基准测试结果的图表。我们还测试了各种矩阵大小,x 轴表示对数刻度上的乘法次数,y 轴表示平均时间(以毫秒为单位)。

Triton 内核与 torch.compile 的比较

我们还尝试使用 BitBlas,这是一个旨在执行混合精度矩阵运算的软件库。它允许以较低精度格式(如 INT8、INT4 甚至 INT2)而不是传统的 FP32 或 FP16 格式进行计算,从而帮助优化这些运算。

基准测试结果令人鼓舞,因为 BitBlas 在低精度方面的表现优于我们的自定义内核和 Torch 的 matmul 函数,如图所示。

Bitblas 基准测试

然而,在模型加载期间,BitBlas 需要编译适合权重矩阵形状的内核并将其存储在本地数据库中,这会增加初始加载时间。

6、结束语

总之,随着 LLM 的不断扩展,通过量化减少其计算需求至关重要。本博客探讨了使用三元权重的 1.58 位量化方法。虽然 1.58 位预训练模型需要大量资源,但我们已经证明,通过一些技巧,可以将现有模型微调到这种精度水平,在不牺牲准确性的情况下实现高效的性能。通过专用内核优化推理速度,BitNet 为使 LLM 更加实用和可扩展开辟了新的可能性。


原文链接:Fine-tuning LLMs to 1.58bit: extreme quantization made easy

BimAnt翻译整理,转载请标明出处