Flamingo: мультимодальная модель от DeepMind

Flamingo – мультимодальная модель DeepMind, генерирующая текстовое описание фото, видео и звуков. Модель превосходит предыдущие state-of-the-art модели в 16 задачах, а ее особенностью является возможность обучаться на нескольких примерах.
Обычно для того, чтобы визуальная модель освоила новую задачу, она должна быть обучена на десятках тысяч примеров, специально размеченных для этой задачи. Если цель состоит в том, чтобы подсчитать и идентифицировать животных на изображении, нужно было бы собрать тысячи изображений животных с указанием их количества и вида. Этот процесс неэффективен, дорог и ресурсоемок, требует больших объемов размеченных данных и необходимости обучать новую модель каждый раз, когда она сталкивается с новой задачей.
Flamingo – few-shot-модель, которая решает данную проблему в широком спектре мультимодальных задач. На основе нескольких примеров пар визуальных входных данных и ожидаемых текстовых ответов, модели можно задать вопрос с новым изображением или видео, а затем сгенерировать ответ.

В 16 задачах c 4 парами примеров, на которых была протестирована Flamingo, модель все предыдущие state-of-the-art подходы. Во Flamingo объединяются большие языковые модели, визуальные представления (каждое из которых было предварительно обучено), и разработанные в DeepMind архитектурные компоненты между ними. Затем модель обучается на смеси дополнительных неразмеченных крупномасштабных мультимодальных данных из Интернета.
Модель Flamingo превосходит существующие модели, которые точно настроены и оптимизированы для каждой задачи и используют на несколько порядков больше данных, специфичных для задачи. Flamingo позволит неспециалистам быстро и легко использовать точные модели визуального языка для решения новых задач.
Установка:
pip install flamingo-pytorch
Полный рабочий пример с Flamingo + PaLM 🌴🦩🌴
from vit_pytorch.vit import ViT
from vit_pytorch.extractor import Extractor
vit = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
vit = Extractor(vit, return_embeddings_only = True)
# first take your trained image encoder and wrap it in an adapter that returns the image embeddings
# here we use the ViT from the vit-pytorch library
import torch
from flamingo_pytorch import FlamingoPaLM
# a PaLM language model, the 540 billion parameter model from google that shows signs of general intelligence
flamingo_palm = FlamingoPaLM(
num_tokens = 20000, # number of tokens
dim = 1024, # dimensions
depth = 12, # depth
heads = 8, # attention heads
dim_head = 64, # dimension per attention head
img_encoder = vit, # plugin your image encoder (this can be optional if you pass in the image embeddings separately, but probably want to train end to end given the perceiver resampler)
media_token_id = 3, # the token id representing the [media] or [image]
cross_attn_every = 3, # how often to cross attend
perceiver_num_latents = 64, # perceiver number of latents, should be smaller than the sequence length of the image tokens
perceiver_depth = 2 # perceiver resampler depth
)
# train your PaLM as usual
text = torch.randint(0, 20000, (2, 512))
palm_logits = flamingo_palm(text)
# after much training off the regular PaLM logits
# now you are ready to train Flamingo + PaLM
# by passing in images, it automatically freezes everything but the perceiver and cross attention blocks, as in the paper
dialogue = torch.randint(0, 20000, (4, 512))
images = torch.randn(4, 2, 3, 256, 256)
flamingo_logits = flamingo_palm(dialogue, images)
# do your usual cross entropy loss
Подробнее: https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model
Код: https://github.com/lucidrains/flamingo-pytorch