SHORTCUT MODELS: метод обучение диффузионных моделей генерации в 1 шаг.Shortcut…
Shortcut models – метод обучения диффузионных моделей, который позволяет генерировать изображения высокого качества за один или несколько шагов.
В основе shortcut models – идея обучать сеть с учетом не только текущего уровня шума, но и желаемого размера шага. Это позволяет модели “перепрыгивать” через этапы генерации.
Ключевым преимуществом данного подхода является его простота: shortcut models обучаются за один этап, используя одну сеть, в отличие от других методов ускорения выборки, которые полагаются на сложные схемы обучения с несколькими фазами, сетями или точной настройкой шедулера.
В процессе обучения shortcut models используются два типа целей loss function:
Совместная оптимизация этих целей дает возможность модели научиться создавать изображения, сохраняя согласованность при любом размере шага, включая генерацию за один шаг.
Метод применим к flow-matching и transformer-based типам моделей и RNN/LSTM-сетям.
Эксперименты, проведенные с DiT на наборах данных CelebA-HQ и ImageNet-256, подтверждают эффективность метода.
Shortcut models превосходят методы “end-to-end” обучения одношаговых генеративных моделей и конкурируют с двухэтапными методами дистилляции.
Практическая реализация shortcut models написана на JAX. Для локального запуска следует установить зависимости conda из файлов environment.yml и requirements.txt репозитория.
⚠️ Код поддерживает --model.sharding fsdp для полностью сегментированного параллелизма данных, если обучение проводится на multi-GPU или TPU.
⚠️ Чекпоинты и FID для тестовых датасетов CelebA и Imagenet доступны на Google-диске.
python train.py --model.hidden_size 768 --model.patch_size 2 --model.depth 12 --model.num_heads 12 --model.mlp_ratio 4
--dataset_name celebahq256 --fid_stats data/celeba256_fidstats_ours.npz --model.cfg_scale 0 --model.class_dropout_prob 1 --model.num_classes 1 --batch_size 64 --max_steps 410_000 --model.train_type shortcut
