SALSA: Стабильная адаптация линейного поиска Armijo.

SALSA (Stable Armijo Line Search Adaptation) — метод, разработанный для оптимизации Learning Rate (LR) во время обучения. 

Основная концепция метода построена вокруг выполнения линейного поиска для определения наилучшего возможного LR для каждого шага обучения, что дает быструю сходимость и улучшенное обобщение.

SALSA: Стабильная адаптация линейного поиска Armijo.

Чтобы уменьшить вычислительную нагрузку, Salsa предлагает пошаговый миниатюрный линейный поиск. В нем LR постепенно увеличивается с каждым шагом, а критерий линейного поиска постоянно переоценивается. 

SALSA: Стабильная адаптация линейного поиска Armijo.

Дополнительно, Salsa включает экспоненциальное сглаживание в процесс линейного поиска и устанавливает два экспоненциальных скользящих средних для скорости обучения. Это помогает стабилизировать оптимизацию и уменьшить нестабильность от мини-пакетирования.

Экспериментальные результаты показывают, что Salsa превосходит другие методы оптимизации: 50% сокращение final loss и 1,25 average rank в языковых и графических задачах. 

SALSA: Стабильная адаптация линейного поиска Armijo.

Вычислительные издержки Salsa всего на 3% выше, чем у базового LR метода, что можно воспринимать как незначительным увеличением, учитывая показатели производительности. Salsa достаточно универсален, чтобы использоваться с различными оптимизаторами, и особенно эффективен при обучении современных архитектур, которые чувствительны к скорости обучения.

SALSA: Стабильная адаптация линейного поиска Armijo.

▶️Локальный запуск:

# Clone repository:

git clone https://github.com/TheMody/No-learning-rates-needed-Introducing-SALSA-Stable-Armijo-Line-Search-Adaptation.git

# Create & activate env:

conda env create -f environment.yml

conda activate sls3

# Install dependencies:

pip install pytorch numpy transformers datasets tensorflow-datasets wandb

# NOTE: custom optimizer is in \salsa\SaLSA.py,comparison version are in \salsa\adam_sls.py:

from salsa.SaLSA import SaLSA

self.optimizer = SaLSA(model.parameters())

# NOTE: typical pytorch forward pass needs to be changed to:

def closure(backwards = False):

    y_pred = model(x)

    loss = criterion(y_pred, y)

    if backwards: loss.backward()

    return loss

optimizer.zero_grad()

loss = optimizer.step(closure = closure)

📌Лицензирование :  MIT License

🟡Arxiv

🟡Датасет Cifar-10

🟡Youtube video

🖥Github [ Stars: 11 | Issues: 0 | Forks: 0]

@ai_machinelearning_big_data

#AI  #LLM #ML #Train #SALSA

+1
0
+1
0
+1
0
+1
0
+1
0

Ответить

Ваш адрес email не будет опубликован. Обязательные поля помечены *