推测性解码:实现 Whisper 推理速度提升两倍 [译]

Sanchit Gandhi 发表了一篇关于 Whisper 模型的最新研究进展。这是一个由 Open AI 开发的先进语音转录模型,能够在多种基准测试和不同的音频环境下展示出色的性能。其最新版本,名为 large-v3,已在开源语音转录模型领域名列前茅,特别是在英语转录方面表现卓越。这个模型还具备出色的多语言性能,在 Common Voice 15 的数据集中测试了 58 种语言,其中有 42 种语言的单词错误率低于 30%。

然而,尽管转录准确率令人印象深刻,但在处理速度上存在不足。即使使用了一些先进的推理优化技术,如 flash attention、半精度计算和音频分块处理,一个小时的音频片段在 16GB T4 GPU 上的处理时间仍需超过 6 分钟。

在这篇博客文章中,我们展示了如何应用“猜测式解码”(Speculative Decoding) 技术来减少 Whisper 语音识别模型的处理时间,实现了处理速度的 两倍提升,同时数学上保证了模型输出的 完全一致性。因此,这一方法可以无缝替代现有的 Whisper 处理流程,不仅保持了原有的准确性,还能实现处理速度的双倍快速提升。想要阅读更加精简的版本,其中包括较少的解释但涵盖所有代码,请参考附带的 Google Colab

推测性解码

推测性解码这一概念由 Google 的 Yaniv Leviathan 等人在 《通过推测性解码实现 Transformer 快速推理》中首次提出。这种方法基于一个假设:一个速度更快的 助理模型 往往可以生成与更大的 主模型 相同的 Token。

首先,助理模型以自回归式地生成一系列 候选 Token,表示为 (y^)1:N\left(\hat{\mathbf{\mathit{y}}}\right)_{1 : N}。例如,在下图中,助理模型生成了一个由 5 个 Token 组成的序列:The quick brown sock jumps

虽然这些候选 Token 能迅速生成,但它们有可能与主模型的预测不同。因此,在第二步中,这些候选 Token 需要提交给主模型进行验证。主模型接收这些候选 Token,并进行一次 单次前向传递。主模型的输出是 Token 序列每一步的“正确”Token,表示为 (y)1:N\left(\mathbf{\mathit{y}}\right)_{1 : N}

以上图为例,我们可以看到主模型预测的前三个 Token(The quick brown)与助理模型的预测相同。但是,助理模型的第四个候选 Token(sock)与主模型的正确 Token(fox)不一致。

我们知道直至第一个不匹配之前的所有候选 Token 都是正确的(The quick brown),因为它们与主模型的预测相符。但在第一个不匹配点之后,候选 Token 开始与主模型的实际预测偏离。于是,我们可以用主模型的正确 Token(fox)替换掉第一个错误的候选 Token(sock),并且放弃所有此后的预测 Token,因为它们已经产生了偏离。经过修正的序列 The quick brown fox 成为新的输入,重新提交给助理模型。

接下来,推理过程再次开始,助理模型生成新一轮的候选 Token,这些 Token 随后在主模型的一次前向传递中得到验证。

我们通过使用快速的助理模型进行自回归生成,仅利用较慢的主模型进行验证前向传递,大幅提升了解码过程的速度。此外,主模型的验证前向传递确保我们得到的输出与单独使用主模型时完全一致。这意味着推测性解码可以完美融入现有的 Whisper 系统中,保证输出质量不变。

要显著提高响应速度,助理模型的运行速度应远快于主模型,并且尽可能频繁地预测相同的 Token 分布。实际操作中,这两个特性需要权衡:模型运行越快,准确度通常越低。不过,由于大约 70-80% 的预测 Token 较为简单,这种权衡更倾向于选择速度更快的模型而非准确度更高的模型。因此,助理模型的速度至少应是主模型的 3 倍(越快越好),能够准确预测所有简单的 Token。而剩余 20-30% 较难的 Token 则由更大型的主模型进行验证。

选择助理模型时,唯一的限制是它必须与主模型共用相同的词汇库。也就是说,助理模型需要使用与主模型完全相同的 Tokenizer。因此,如果我们想用 Whisper 多语言版本进行推测性解码,如 large-v2(多语言版),我们需要选择一个多语言版本的 Whisper 助理模型,例如 tiny。而对于仅英语版本的 Whisper,如 medium.en,我们则需要选择相应的英语版助理模型,例如 tiny.en。目前,Whisper large-v3 是个例外,它是唯一一个扩大了词汇量的版本,因此与之前的 Whisper 版本不兼容。

现在我们已经理解了推测性解码(speculative decoding)的背景知识,我们可以开始探索它的实际应用了。在 🤗 Transformers 库里,推测性解码是作为“辅助生成(assisted generation)”这一推理策略来实施的。这一策略是如何运作的呢?简而言之,它可以在生成文本的过程中提供额外的辅助,以提高结果的准确性和相关性。要深入了解这一技术的具体实现,你可以阅读 Joao Gante 的精彩博文 辅助生成,其中详细介绍了这一策略。

英语语音转录

基线实现

我们的工作从对 Whisper large-v2 进行基准测试开始,目的是确定我们推理速度的基准值。主模型和相应处理器的加载可以通过 AutoModelForSpeechSeq2SeqAutoProcessor 这两个便捷的类来完成。我们用 float16 精度加载模型,为了缩短加载时间,我们设置 low_cpu_mem_usage=True。另外,我们确保模型以 safetensors 格式加载,这可以通过设置 use_safetensors=True 来实现。最后,我们设置 attn_implementation="sdpa" 参数,利用 PyTorch 的 SDPA 注意力机制,以便从 Flash Attention 获取速度提升:

import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "openai/whisper-large-v2"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
use_safetensors=True,
attn_implementation="sdpa",
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)

接下来,我们将加载用于基准测试的英语语音转录数据集。我们选择了包含 73 个样本的 LibriSpeech ASR 验证干净数据集的一个小型数据集。这个数据集大约有 9MB,非常轻量,便于在设备上快速下载:

from datasets import load_dataset
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

在基准测试中,我们的重点是测量生成步骤所需的时间。因此,我们编写了一个简单的函数来执行这项测量。这个函数将返回解码的 tokens 和模型运行所需的时间:

import time
def generate_with_time(model, inputs, **kwargs):
start_time = time.time()
outputs = model.generate(**inputs, **kwargs)
generation_time = time.time() - start_time
return outputs, generation_time

现在,我们可以遍历数据集中的音频样本,累计总的生成时间:

from tqdm import tqdm
all_time = 0
predictions = []
references = []
for sample in tqdm(dataset):
audio = sample["audio"]
inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
inputs = inputs.to(device=device, dtype=torch.float16)
output, gen_time = generate_with_time(model, inputs)
all_time += gen_time
predictions.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])
references.append(processor.tokenizer._normalize(sample["text"]))
print(all_time)

输出:

100%|██████████| 73/73 [01:37<00:00, 1.33s/it]
72.99542546272278

我们发现,转写这 73 个音频样本用了 73 秒。现在,让我们来看一下这些预测的词错误率 (WER):

from evaluate import load
wer = load("wer")
print(wer.compute(predictions=predictions, references=references))

输出: 0.03507271171941831

Our final baseline number is 73 seconds for a WER of 3.5%.
### Speculative Decoding
Now let's load the assistant model for speculative decoding. In this example, we'll use a distilled variant of Whisper,
[distil-large-v2](https://huggingface.co/distil-whisper/distil-large-v2). The distilled model copies the entire encoder
from Whisper, but only 2 of the 32 decoder layers. As such, it runs 6x faster than Whisper, while performing to within
1% WER on out-of-distribution test sets. This makes it the perfect choice as an assistant model, since it has both
high transcription accuracy and fast generation 11.
Since Distil-Whisper uses exactly the same encoder as the Whisper model, we can share the encoder across the main and
assistant models. We then only have to load the 2-layer decoder from Distil-Whisper as a "decoder-only" model. We can do
this through the convenient [`AutoModelForCausalLM`](https://huggingface.co/docs/transformers/model_doc/auto#transformers.AutoModelForCausalLM)
auto class. In practice, this results in only an 8% increase to VRAM over using the main model alone.
```python
from transformers import AutoModelForCausalLM
assistant_model_id = "distil-whisper/distil-large-v2"
assistant_model = AutoModelForCausalLM.from_pretrained(
assistant_model_id,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
use_safetensors=True,
attn_implementation="sdpa",
)
assistant_model.to(device);

1^{1} 我们计划推出 Distil-Whisper 的升级版,这个版本在 Token 分布上的对齐性更强,将进一步提升推理解码的效率。欲了解最新消息,请关注 Distil-Whisper 项目仓库


我们可以修改一个函数,用于评估我们的推理解码性能。与之前的函数唯一不同之处在于,我们在调用 .generate 方法时加入了助理模型:

def assisted_generate_with_time(model, inputs, **kwargs):
start_time = time.time()
outputs = model.generate(**inputs, assistant_model=assistant_model, **kwargs)
generation_time = time.time() - start_time
return outputs, generation_time

现在,让我们利用 Distil-Whisper 作为 Whisper 的辅助模型,执行带推理解码的基准测试:

all_time = 0
predictions = []
references = []
for sample in tqdm(dataset):
audio = sample["audio"]
inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
inputs = inputs.to(device=device, dtype=torch.float16)
output, gen_time = assisted_generate_with_time(model, inputs)
all_time += gen_time
predictions.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])
references.append(processor.tokenizer._normalize(sample["text"]))
print(all_time)

输出:

100%|██████████| 73/73 [00:38<00:00, 1.88it/s]
32.69683289527893

利用推理解码,推断时间仅为 33 秒,比之前快了 2.2 倍!我们再来核对一下词错误率是否保持不变:

print(wer.compute(predictions=predictions, references=references))

输出:

0.03507271171941831

太好了!词错误率仍然是 3.5%,这意味着我们得到了与单独使用主模型时完全相同的结果。

推理解码也可以通过简单易用的 🤗 Transformers pipeline API 来实现。下面我们使用模型和处理器初始化 pipeline,然后用它来转录我们的示例数据集中的第一个样本。这个方法同样适用于转录任意长度的音频样本,甚至可以实现批量处理:

from transformers import pipeline
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
chunk_length_s=15,
batch_size=4,
generate_kwargs={"assistant_model": assistant_model},
torch_dtype=torch_dtype,
device=device,
)
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])

输出:

Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.

如果你想要实现带有 Whisper 和 Distil-Whisper 的推理解码,可以参考 Distil-Whisper 模型说明,那里提供了一个从头到尾的代码示例,涵盖了我们在这个笔记本中介绍的所有推理步骤。

多语言语音转录

Distil-Whisper 是一款理想的英语语音转录辅助模型,它在字错误率(WER)上仅比原始的 Whisper 模型低 1%,而在处理短音频和长音频样本方面速度提升了 6 倍。然而,官方的 Distil-Whisper 模型仅支持英语,因此不能用于多语种的语音转录。

要想利用推测解码(speculative decoding)技术进行多语言语音转录,可以选择官方提供的多语言 Whisper 模型(例如 Whisper large-v2),或者选择经过针对性微调的 Whisper 模型版本。截至目前,Hugging Face 平台上已经有超过 5,000 个针对超过 100 种语言进行微调的 Whisper 模型(点击查看微调模型),这为我们在特定语言上选择高性能的 Whisper 辅助模型提供了丰富的选择。在这里,我们将使用官方提供的最小型多语言模型 Whisper tiny。您也可以根据需要尝试使用其他针对您的语言微调过的模型。

下面,我们将加载我们新的辅助模型 Whisper tiny 的权重。由于 Whisper tiny 的编码器(encoder)与 large-v2 版本的有所不同,这次我们需要同时加载编码器和解码器(decoder),我们使用 AutoModelForSpeechSeq2Seq 类来完成这一操作:

assistant_model_id = "openai/whisper-tiny"
assistant_model = AutoModelForSpeechSeq2Seq.from_pretrained(
assistant_model_id,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
use_safetensors=True,
attn_implementation="sdpa",
)
assistant_model.to(device);

对于我们的基准测试数据集,我们选用了 VoxPopuli 数据集中的 73 个荷兰语("nl")样本:

dataset = load_dataset("sanchit-gandhi/voxpopuli_dummy", "nl", split="validation")

太棒了!现在我们可以像之前一样,重新对我们的基准 Whisper large-v2 模型进行测试。唯一的改变是,我们在生成函数(generate function)中加入了语言和任务参数,以确保我们进行的是语音转录(speech transcription),而非语音翻译(speech translation)。推测解码技术既适用于语音转录,也适用于语音翻译任务。您只需根据下方所示调整任务参数即可:

all_time = 0
predictions = []
references = []
for sample in tqdm(dataset):
audio = sample["audio"]
inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
inputs = inputs.to(device=device, dtype=torch.float16)
output, gen_time = generate_with_time(model, inputs, language="nl", task="transcribe")
all_time += gen_time
predictions.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])
references.append(processor.tokenizer._normalize(sample["normalized_text"]))
wer_result = wer.compute(predictions=predictions, references=references)
print("Time:", all_time)
print("WER:", wer_result)

输出结果:

100%|██████████| 73/73 [02:05<00:00, 1.72s/it]
Time: 116.50992178916931
WER: 0.127190136275146

好的,我们得到了基准测试的结果:117 秒的处理时间和 12.8% 的字错误率。现在,让我们使用推测解码技术重新执行生成过程:

all_time = 0
predictions = []
references = []
for sample in tqdm(dataset):
audio = sample["audio"]
inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
inputs = inputs.to(device=device, dtype=torch.float16)
output, gen_time = assisted_generate_with_time(model, inputs, language="nl", task="transcribe")
all_time += gen_time
predictions.append(processor.batch_decode(output, skip_special_tokens=True, normalize=True)[0])
references.append(processor.tokenizer._normalize(sample["normalized_text"]))
wer_result = wer.compute(predictions=predictions, references=references)
print("Time:", all_time)
print("WER:", wer_result)

输出结果:

100%|██████████| 73/73 [01:08<00:00, 1.06it/s]
Time: 62.10229682922363
WER: 0.127190136275146

我们再次达到了 12.8% 的词错误率(WER),但仅在 62 秒的推理时间内完成,相比之前速度提升了 1.9 倍。考虑到助理模型加载的低成本以及能够产生完全一样的输出这一数学特性,推测性解码为现有 Whisper 语音转写系统提供了一个理想的替代方案。

高效推测性解码的策略

在这最后一节中,我们介绍两种确保推测性解码能够以最快速度进行推理的策略。

助理模型

我们的目标是选用一个至少比主模型快 3 倍的助理模型,且能至少准确转录 70-80% 的预测词元(Token),通常是样例中较简单的部分。如果你专注于某一特定语言的转录,一个有效的策略是训练两个不同规模的 Whisper 模型,并将其中一个作为另一个的辅助模型:

  • 首先,对 Whisper large-v3 进行微调,使其成为你的主模型。
  • 接着,对 Whisper large-v3 在同一数据集上进行蒸馏,使其成为一个快速的助理模型。

通过微调和蒸馏,可以提升你选定语言的主模型和助理模型在词错误率(WER)方面的性能,同时确保它们在词元分布上的最大对齐。关于 Whisper 微调的详细指南可以参考这里,关于蒸馏的信息可在这里查看。

批处理大小

值得一提的是,在使用推测性解码(一种预测下一步操作的技术)的过程中,当批处理大小为 1 时,我们可以获得最大的速度提升。在批量推测性解码中,整个批处理的所有候选 Token 都必须与验证 Token 完全一致,才会被认定为有效。如果批处理中某个特定位置的 Token 出现不一致,那么该位置之后的所有候选 Token 都会被放弃。这就是为什么推测性解码更适合较小的批处理规模。实际操作中,我们发现当批处理大小不超过 4 时,推测性解码可以有效提升处理速度。但一旦批处理大小超过 4,使用推测性解码的推理速度就会比仅用主模型慢。详细的实验结果,请参考 Distil-Whisper 论文中的 D.3 节。

结论

在这篇博文中,我们探讨了推测性解码在语音转录领域的应用,特别是在 Whisper 模型中的运用。我们展示了如何通过这种方法实现高达两倍的处理速度提升,同时从数学上保证输出结果与仅使用原始模型时一致。鉴于额外引入的助理模型带来的开销极低,我们鼓励您尝试在现有的 Whisper 处理流程中使用推测性解码,这能够确保您获得与原模型同样的转录结果。

致谢

本文由 Sanchit Gandhi 撰写。特别感谢 Patrick von PlatenPedro Cuenca 对本文提出的宝贵意见,以及 Joao Gante 在 🤗 Transformers 中实现辅助生成功能的杰出贡献。