Shortcuts

TorchMultimodal教程:微调FLAVA

Created On: Oct 27, 2022 | Last Updated: Apr 11, 2023 | Last Verified: Nov 05, 2024

多模态AI因其普遍特性最近变得非常流行,从图像字幕和视觉搜索等用例到最近基于文本生成图像的应用。TorchMultimodal是一个由PyTorch支持的库,包含构建模块和端到端示例,旨在启用和加速多模态研究。

在本教程中,我们将展示如何使用 TorchMultimodal库中的预训练的SoTA模型 FLAVA 在一个多模态任务上进行微调,即视觉问答 (VQA)。该模型由两个基于文本和图像的单模态Transformer编码器以及一个将两个嵌入结合的多模态编码器组成。它通过对比学习、图像文本匹配以及图像、文本和多模态掩码损失进行了预训练。

安装

在本教程中,我们将使用来自Hugging Face的TextVQA数据集和``bert tokenizer``。因此你需要安装datasets和transformers以及TorchMultimodal。

备注

如果在Google Colab中运行本教程,通过创建一个新单元格并运行以下命令来安装所需的包:

!pip install torchmultimodal-nightly
!pip install datasets
!pip install transformers

步骤

  1. 运行以下命令,下载Hugging Face数据集到你的电脑目录:

    wget http://dl.fbaipublicfiles.com/pythia/data/vocab.tar.gz
    tar xf vocab.tar.gz
    

    备注

    如果在Google Colab中运行本教程,请在新单元格中运行这些命令,并在命令前加上感叹号(!)

  2. 对于本教程,我们将VQA视为一个分类任务,其中输入是图像和问题(文本),输出是一个答案类别。因此我们需要下载带有答案类别的词汇文件并创建答案到标签的映射。

    我们还从Hugging Face加载了 textvqa数据集,其中包含34602个训练样本(图像、问题和答案)。

我们发现有3997个答案类别,包括一个表示未知答案的类别。

with open("data/vocabs/answers_textvqa_more_than_1.txt") as f:
  vocab = f.readlines()

answer_to_idx = {}
for idx, entry in enumerate(vocab):
  answer_to_idx[entry.strip("\n")] = idx
print(len(vocab))
print(vocab[:5])

from datasets import load_dataset
dataset = load_dataset("textvqa")

让我们显示数据集中的一个示例条目:

import matplotlib.pyplot as plt
import numpy as np
idx = 5
print("Question: ", dataset["train"][idx]["question"])
print("Answers: " ,dataset["train"][idx]["answers"])
im = np.asarray(dataset["train"][idx]["image"].resize((500,500)))
plt.imshow(im)
plt.show()

3. Next, we write the transform function to convert the image and text into Tensors consumable by our model - For images, we use the transforms from torchvision to convert to Tensor and resize to uniform sizes - For text, we tokenize (and pad) them using the BertTokenizer from Hugging Face - For answers (i.e. labels), we take the most frequently occurring answer as the label to train with:

import torch
from torchvision import transforms
from collections import defaultdict
from transformers import BertTokenizer
from functools import partial

def transform(tokenizer, input):
  batch = {}
  image_transform = transforms.Compose([transforms.ToTensor(), transforms.Resize([224,224])])
  image = image_transform(input["image"][0].convert("RGB"))
  batch["image"] = [image]

  tokenized=tokenizer(input["question"],return_tensors='pt',padding="max_length",max_length=512)
  batch.update(tokenized)


  ans_to_count = defaultdict(int)
  for ans in input["answers"][0]:
    ans_to_count[ans] += 1
  max_value = max(ans_to_count, key=ans_to_count.get)
  ans_idx = answer_to_idx.get(max_value,0)
  batch["answers"] = torch.as_tensor([ans_idx])
  return batch

tokenizer=BertTokenizer.from_pretrained("bert-base-uncased",padding="max_length",max_length=512)
transform=partial(transform,tokenizer)
dataset.set_transform(transform)

4. Finally, we import the flava_model_for_classification from torchmultimodal. It loads the pretrained FLAVA checkpoint by default and includes a classification head.

该模型的前向函数将图像通过视觉编码器,问题通过文本编码器。图像和问题嵌入然后通过多模态编码器。最终对应CLS标记的嵌入通过MLP头,最终提供每个可能答案的概率分布。

from torchmultimodal.models.flava.model import flava_model_for_classification
model = flava_model_for_classification(num_classes=len(vocab))

5. We put together the dataset and model in a toy training loop to demonstrate how to train the model for 3 iterations:

from torch import nn
BATCH_SIZE = 2
MAX_STEPS = 3
from torch.utils.data import DataLoader

train_dataloader = DataLoader(dataset["train"], batch_size= BATCH_SIZE)
optimizer = torch.optim.AdamW(model.parameters())


epochs = 1
for _ in range(epochs):
  for idx, batch in enumerate(train_dataloader):
    optimizer.zero_grad()
    out = model(text = batch["input_ids"], image = batch["image"], labels = batch["answers"])
    loss = out.loss
    loss.backward()
    optimizer.step()
    print(f"Loss at step {idx} = {loss}")
    if idx >= MAX_STEPS-1:
      break

总结

本教程介绍了如何使用TorchMultimodal中的FLAVA微调多模态任务的基础内容。另外,请查看库中的其他示例,如 MDETR ,这是一个用于目标检测的多模态模型,以及 Omnivore ,它适用于图像、视频和3D分类的多任务模型。

脚本总运行时间: (0 分钟 0.000 秒)

画廊由 Sphinx-Gallery 生成

文档

访问 PyTorch 的详细开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源