视觉语言模型详解
视觉语言模型(Visual Language Models)是可以同时从图像和文本中学习以处理许多任务的模型,从视觉问答到图像字幕。在这篇文章中,我们将介绍视觉语言模型的主要组成部分:概述,了解它们的工作原理,弄清楚如何找到合适的模型,如何使用它们进行推理以及如何使用新版 trl 轻松微调它们!
1、什么是视觉语言模型?
视觉语言模型被广泛定义为可以从图像和文本中学习的多模态模型。它们是一种生成式模型,可以接受图像和文本输入并生成文本输出。
大型视觉语言模型具有良好的零样本能力,具有良好的泛化能力,并且可以处理多种类型的图像,包括文档、网页等。用例包括谈论图像、通过指令进行图像识别、视觉问答、文档理解、图像字幕等。
一些视觉语言模型还可以捕获图像中的空间属性。当系统提示检测或分割特定主题时,这些模型可以输出边界框或分割掩码,也可以定位不同的实体或回答有关其相对或绝对位置的问题。
现有的大型视觉语言模型集、它们所训练的数据、它们对图像的编码方式以及它们的能力都存在很大差异。
2、开源视觉语言模型概述
Hugging Face Hub 上有许多开源视觉语言模型。下表列出了一些最突出的模型。
- 有基础模型,也有针对聊天进行微调的模型,可用于对话模式。
- 其中一些模型具有称为“grounding”的功能,可减少模型幻觉。
- 除非另有说明,所有模型均使用英语进行训练。
Model | Permissive License | Model Size | Image Resolution | Additional Capabilities |
---|---|---|---|---|
LLaVA 1.6 (Hermes 34B) | ✅ | 34B | 672x672 | |
deepseek-vl-7b-base | ✅ | 7B | 384x384 | |
DeepSeek-VL-Chat | ✅ | 7B | 384x384 | Chat |
moondream2 | ✅ | ~2B | 378x378 | |
CogVLM-base | ✅ | 17B | 490x490 | |
CogVLM-Chat | ✅ | 17B | 490x490 | Grounding, chat |
Fuyu-8B | ❌ | 8B | 300x300 | Text detection within image |
KOSMOS-2 | ✅ | ~2B | 224x224 | Grounding, zero-shot object detection |
Qwen-VL | ✅ | 4B | 448x448 | Zero-shot object detection |
Qwen-VL-Chat | ✅ | 4B | 448x448 | Chat |
Yi-VL-34B | ✅ | 34B | 448x448 | Bilingual (English, Chinese) |
3、寻找合适的视觉语言模型
有很多方法可以为你的用例选择最合适的模型。
Vision Arena 是一个完全基于模型输出匿名投票的排行榜,并不断更新。在这个竞技场中,用户输入一张图片和一个提示,然后匿名抽取两个不同模型的输出,然后用户可以选择他们喜欢的输出。这样,排行榜就完全基于人类的偏好构建。
Open VLM Leaderboard 是另一个排行榜,其中根据这些指标和平均分数对各种视觉语言模型进行排名。你还可以根据模型大小、专有或开源许可证筛选模型,并根据不同的指标进行排名。
VLMEvalKit 是一个工具包,用于在支持 Open VLM Leaderboard 的视觉语言模型上运行基准测试。另一个评估套件是 LMMS-Eval,它提供了一个标准命令行界面,可以使用托管在 Hugging Face Hub 上的数据集来评估你选择的 Hugging Face 模型,如下所示:
accelerate launch --num_processes=8 -m lmms_eval --model llava --model_args pretrained="liuhaotian/llava-v1.5-7b" --tasks mme,mmbench_en --batch_size 1 --log_samples --log_samples_suffix llava_v1.5_mme_mmbenchen --output_path ./logs/
Vision Arena 和 Open VLM Leaderbard 都仅限于提交给它们的模型,并且需要更新才能添加新模型。如果你想查找其他模型,可以浏览 Hub 中图像-文本-文本任务下的模型。
你可能会在排行榜中遇到不同的用于评估视觉语言模型的基准。我们将介绍其中的一些。
- MMMU
面向专家 AGI 的大规模多学科多模态理解和推理基准 (MMMU) 是评估视觉语言模型的最全面的基准。它包含 11.5K 个多模态挑战,需要跨艺术和工程等不同学科的大学水平学科知识和推理能力。
- MMBench
MMBench 是一个评估基准,包含 3000 个单选题,涉及 20 种不同的技能,包括 OCR、对象定位等。该论文还介绍了一种名为 CircularEval 的评估策略,其中问题的答案选项以不同的组合进行打乱,并且模型有望每次都给出正确答案。还有其他跨不同领域的更具体的基准,包括 MathVista(视觉数学推理)、AI2D(图表理解)、ScienceQA(科学问答)和 OCRBench(文档理解)。
4、技术细节
有多种方法可以预训练视觉语言模型。主要技巧是统一图像和文本表示并将其提供给文本解码器进行生成。最常见和最突出的模型通常由图像编码器、用于对齐图像和文本表示的嵌入投影器(通常是密集神经网络)和按此顺序堆叠的文本解码器组成。至于训练部分,不同的模型一直遵循不同的方法。
例如,LLaVA 由 CLIP 图像编码器、多模态投影器和 Vicuna 文本解码器组成。作者将图像和标题的数据集输入 GPT-4,并生成与标题和图像相关的问题。作者冻结了图像编码器和文本解码器,只训练了多模态投影仪,通过输入模型图像和生成的问题并将模型输出与真实标题进行比较来对齐图像和文本特征。在投影器预训练之后,他们保持图像编码器冻结,解冻文本解码器,并使用解码器训练投影仪。这种预训练和微调方式是训练视觉语言模型最常见的方式。
另一个例子是 KOSMOS-2,作者选择对模型进行端到端的全面训练,与 LLaVA 类预训练相比,这在计算上非常昂贵。作者后来进行了纯语言指令微调以对齐模型。还有一个例子是 Fuyu-8B,它甚至没有图像编码器。相反,图像块直接输入到投影层,然后序列经过自回归解码器。
大多数情况下,你不需要预先训练视觉语言模型,因为可以使用现有模型之一,也可以根据自己的用例对其进行微调。接下来我们将介绍如何使用 transformer 使用这些模型,以及如何使用 SFTTrainer 进行微调。
5、用transformer 执行视觉语言模型
你可以使用 LlavaNext 模型通过 Llava 进行推断,如下所示。
让我们先初始化模型和处理器:
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
model = LlavaNextForConditionalGeneration.from_pretrained(
"llava-hf/llava-v1.6-mistral-7b-hf",
torch_dtype=torch.float16,
low_cpu_mem_usage=True
)
model.to(device)
现在我们将图像和文本提示传递给处理器,然后将处理后的输入传递给生成。请注意,每个模型都使用自己的提示模板,请小心使用正确的模板以避免性能下降。
from PIL import Image
import requests
url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
inputs = processor(prompt, image, return_tensors="pt").to(device)
output = model.generate(**inputs, max_new_tokens=100)
调用解码来解码输出标记:
print(processor.decode(output[0], skip_special_tokens=True))
6、使用 TRL 微调视觉语言模型
我们很高兴地宣布,TRL 的 SFTTrainer 现在包括对视觉语言模型的实验性支持!我们在此提供了一个示例,说明如何使用包含 260k 个图像对话对的 llava-instruct 数据集在 Llava 1.5 VLM 上执行 SFT。该数据集包含格式化为消息序列的用户助手交互。例如,每个对话都与用户询问问题的图像配对。
要使用实验性的 VLM 训练支持,你必须使用 pip install -U trl
安装最新版本的 TRL。完整的示例脚本可以在此处找到。
from trl.commands.cli_utils import SftScriptArguments, TrlParser
parser = TrlParser((SftScriptArguments, TrainingArguments))
args, training_args = parser.parse_args_and_config()
初始化聊天模板,进行指令微调:
LLAVA_CHAT_TEMPLATE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<image>{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}"""
我们现在将初始化我们的模型和标记器:
from transformers import AutoTokenizer, AutoProcessor, TrainingArguments, LlavaForConditionalGeneration
import torch
model_id = "llava-hf/llava-1.5-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.chat_template = LLAVA_CHAT_TEMPLATE
processor = AutoProcessor.from_pretrained(model_id)
processor.tokenizer = tokenizer
model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16)
让我们创建一个数据整理器来组合文本和图像对:
class LLavaDataCollator:
def __init__(self, processor):
self.processor = processor
def __call__(self, examples):
texts = []
images = []
for example in examples:
messages = example["messages"]
text = self.processor.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=False
)
texts.append(text)
images.append(example["images"][0])
batch = self.processor(texts, images, return_tensors="pt", padding=True)
labels = batch["input_ids"].clone()
if self.processor.tokenizer.pad_token_id is not None:
labels[labels == self.processor.tokenizer.pad_token_id] = -100
batch["labels"] = labels
return batch
data_collator = LLavaDataCollator(processor)
加载我们的数据集:
from datasets import load_dataset
raw_datasets = load_dataset("HuggingFaceH4/llava-instruct-mix-vsft")
train_dataset = raw_datasets["train"]
eval_dataset = raw_datasets["test"]
初始化 SFTTrainer,传入模型、数据集分割、PEFT 配置和数据整理器并调用 train()
。要将我们的最终检查点推送到 Hub,请调用 push_to_hub()
。
from trl import SFTTrainer
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
dataset_text_field="text", # need a dummy field
tokenizer=tokenizer,
data_collator=data_collator,
dataset_kwargs={"skip_prepare_dataset": True},
)
trainer.train()
保存模型并推送至 Hugging Face Hub。
trainer.save_model(training_args.output_dir)
trainer.push_to_hub()
可以在此处找到训练好的模型。
原文链接:Vision Language Models Explained
BimAnt翻译整理,转载请标明出处