Flamingo: мультимодальная модель от DeepMind

Flamingo – мультимодальная модель DeepMind, генерирующая текстовое описание фото, видео и звуков. Модель превосходит предыдущие state-of-the-art модели в 16 задачах, а ее особенностью является возможность обучаться на нескольких примерах.

Обычно для того, чтобы визуальная модель освоила новую задачу, она должна быть обучена на десятках тысяч примеров, специально размеченных для этой задачи. Если цель состоит в том, чтобы подсчитать и идентифицировать животных на изображении, нужно было бы собрать тысячи изображений животных с указанием их количества и вида. Этот процесс неэффективен, дорог и ресурсоемок, требует больших объемов размеченных данных и необходимости обучать новую модель каждый раз, когда она сталкивается с новой задачей.

Flamingo – few-shot-модель, которая решает данную проблему в широком спектре мультимодальных задач. На основе нескольких примеров пар визуальных входных данных и ожидаемых текстовых ответов, модели можно задать вопрос с новым изображением или видео, а затем сгенерировать ответ.

В 16 задачах c 4 парами примеров, на которых была протестирована Flamingo, модель все предыдущие state-of-the-art подходы. Во Flamingo объединяются большие языковые модели, визуальные представления (каждое из которых было предварительно обучено), и разработанные в DeepMind архитектурные компоненты между ними. Затем модель обучается на смеси дополнительных неразмеченных крупномасштабных мультимодальных данных из Интернета.

Модель Flamingo превосходит существующие модели, которые точно настроены и оптимизированы для каждой задачи и используют на несколько порядков больше данных, специфичных для задачи. Flamingo позволит неспециалистам быстро и легко использовать точные модели визуального языка для решения новых задач.

Установка:

pip install flamingo-pytorch

Полный рабочий пример с Flamingo + PaLM 🌴🦩🌴

from vit_pytorch.vit import ViT
from vit_pytorch.extractor import Extractor

vit = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

vit = Extractor(vit, return_embeddings_only = True)

# first take your trained image encoder and wrap it in an adapter that returns the image embeddings
# here we use the ViT from the vit-pytorch library

import torch
from flamingo_pytorch import FlamingoPaLM

# a PaLM language model, the 540 billion parameter model from google that shows signs of general intelligence

flamingo_palm = FlamingoPaLM(
    num_tokens = 20000,          # number of tokens
    dim = 1024,                  # dimensions
    depth = 12,                  # depth
    heads = 8,                   # attention heads
    dim_head = 64,               # dimension per attention head
    img_encoder = vit,           # plugin your image encoder (this can be optional if you pass in the image embeddings separately, but probably want to train end to end given the perceiver resampler)
    media_token_id = 3,          # the token id representing the [media] or [image]
    cross_attn_every = 3,        # how often to cross attend
    perceiver_num_latents = 64,  # perceiver number of latents, should be smaller than the sequence length of the image tokens
    perceiver_depth = 2          # perceiver resampler depth
)

# train your PaLM as usual

text = torch.randint(0, 20000, (2, 512))

palm_logits = flamingo_palm(text)

# after much training off the regular PaLM logits
# now you are ready to train Flamingo + PaLM
# by passing in images, it automatically freezes everything but the perceiver and cross attention blocks, as in the paper

dialogue = torch.randint(0, 20000, (4, 512))
images = torch.randn(4, 2, 3, 256, 256)

flamingo_logits = flamingo_palm(dialogue, images)

# do your usual cross entropy loss

Подробнее: https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model

Код: https://github.com/lucidrains/flamingo-pytorch

Ответить