Как сгенерировать сообщение о фиксации git с помощью искусственного интеллекта?

Почему не используют имеющиеся решения? Все они используют ChatGPT. Но у меня закончились кредиты 😉
Конечно, я хочу чему-то научиться!

Как сгенерировать сообщение о фиксации git?

Git позволяет создавать крючки. Давайте воспользуемся глобальным. Глобальные крючки работают без модификации каждого git-репо.

Создайте каталог для крючков:

$ mkdir ~/.config/git/hooks/

Пусть git знает, где находятся крючки:

$ git config core.hooksPath ~/.config/git/hooks/

Короче говоря, prepare-commit-msg – это то, что нам нужно. В качестве первого параметра передается файл, который нам нужно обновить.
Создадим простой скрипт:

#!/bin/sh

echo "Fancy commit message" > $1

Сделайте его исполняемым:

$ chmod +z ~/.confog/git/hooks/prepare-commit-msg

Работает ли это? Давайте зафиксируем что-нибудь … Да, у нас есть сообщение в конце сообщения о фиксации.

Давайте сгенерируем что-нибудь:

Формирование сообщения о фиксации

Давайте создадим что-то, что будет работать в автономном режиме. ИИ? Да, давайте использовать искусственный интеллект!

Нам ведь нужна модель?

Давайте посмотрим на “обнимающееся лицо“!

Вот он: https://huggingface.co/mamiksik/T5-commit-message-generation, но там нет документации 🙁
Но если вы посмотрите глубже, то найдете https://huggingface.co/spaces/mamiksik/commit-message-generator.

Мы можем использовать этот https://huggingface.co/spaces/mamiksik/commit-message-generator/blob/main/app.py с небольшими изменениями.

Поскольку в качестве хука мы можем использовать любой shell-скрипт, давайте воспользуемся python.

Давайте посмотрим, что там есть:

import re

import gradio as gr
import torch
from transformers import T5ForConditionalGeneration, RobertaTokenizer


tokenizer = RobertaTokenizer.from_pretrained("mamiksik/CommitPredictorT5PL", revision="fb08d01")
model = T5ForConditionalGeneration.from_pretrained("mamiksik/CommitPredictorT5PL", revision="fb08d01")

def parse_files(patch):
    accumulator = []
    lines = patch.splitlines()

    filename_before = None
    for line in lines:
        if line.startswith("index") or line.startswith("diff"):
            continue
        if line.startswith("---"):
            filename_before = line.split(" ", 1)[1][1:]
            continue

        if line.startswith("+++"):
            filename_after = line.split(" ", 1)[1][1:]

            if filename_before == filename_after:
                accumulator.append(f"<ide><path>{filename_before}")
            else:
                accumulator.append(f"<add><path>{filename_after}")
                accumulator.append(f"<del><path>{filename_before}")
            continue

        line = re.sub("@@[^@@]*@@", "", line)
        if len(line) == 0:
            continue

        if line[0] == "+":
            line = line.replace("+", "<add>", 1)
        elif line[0] == "-":
            line = line.replace("-", "<del>", 1)
        else:
            line = f"<ide>{line}"

        accumulator.append(line)

    return '\n'.join(accumulator)


def predict(patch, max_length, min_length, num_beams, prediction_count):
    input_text = parse_files(patch)
    with torch.no_grad():
        token_count = tokenizer(input_text, return_tensors="pt").input_ids.shape[1]

        input_ids = tokenizer(
            input_text,
            truncation=True,
            padding=True,
            return_tensors="pt",
        ).input_ids

        outputs = model.generate(
            input_ids,
            max_length=max_length,
            min_length=min_length,
            num_beams=num_beams,
            num_return_sequences=prediction_count,
        )

    result = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return token_count, input_text, {k: 0 for k in result}


iface = gr.Interface(fn=predict, inputs=[
    gr.Textbox(label="Patch (as generated by git diff)"),
    gr.Slider(1, 128, value=40, label="Max message length"),
    gr.Slider(1, 128, value=5, label="Min message length"),
    gr.Slider(1, 10, value=7, label="Number of beams"),
    gr.Slider(1, 15, value=5, label="Number of predictions"),
], outputs=[
    gr.Textbox(label="Token count"),
    gr.Textbox(label="Parsed patch"),
    gr.Label(label="Predictions")
], examples=[
["""
diff --git a/.github/workflows/pylint.yml b/.github/workflows/codestyle_checks.yml
similarity index 86%
rename from .github/workflows/pylint.yml
rename to .github/workflows/codestyle_checks.yml
index a5d5c4d9..8cbf9713 100644
--- a/.github/workflows/pylint.yml
+++ b/.github/workflows/codestyle_checks.yml
@@ -20,3 +20,6 @@ jobs:
     - name: Analysing the code with pylint
       run: |
         pylint --rcfile=.pylintrc webapp core
+    - name: Analysing the code with flake8
+      run: |
+        flake8
""", 40, 5, 7, 5]
]
)

if __name__ == "__main__":
    iface.launch()

Все, что нам нужно, здесь! Нам нужно:

  • получение файла gitmessage для обновления
  • получить git diff
  • использовать текущий скрипт для прогнозирования
  • добавлять сообщение о фиксации в файл gitmessage

Файл, который нам нужно обновить, передается в качестве первого параметра, поэтому

import sys

sys.argv[1]

Это было просто.

Получение git diff

import subprocess

subprocess.run(['git', 'diff', '--cached'], capture_output=True).stdout.decode('utf-8')

Легко и просто!

Использовать текущий сценарий для прогнозирования

max_message = 40
min_message = 5
num_beams = 10
num_predictions = 1

msg = predict(diff, max_message, min_message, num_beams, num_predictions)

Добавьте наше сообщение в файл gitmessage

with open(sys.argv[1], 'r+') as f:
    content = f.read()
    f.seek(0)
    f.write(msg + '\n' + content)
    f.close()

Вот так. После небольших доработок это наш окончательный вариант сценария.

#!/usr/bin/env python
print("Generating commit message", end="", flush=True)

import sys
import re
import subprocess
import torch
from transformers import T5ForConditionalGeneration, RobertaTokenizer

def parse_files(patch):
    accumulator = []
    lines = patch.splitlines()

    filename_before = None
    for line in lines:
        print(".", end="", flush=True)
        if line.startswith("index") or line.startswith("diff"):
            continue
        if line.startswith("---"):
            filename_before = line.split(" ", 1)[1][1:]
            continue

        if line.startswith("+++"):
            filename_after = line.split(" ", 1)[1][1:]

            if filename_before == filename_after:
                accumulator.append(f"<ide><path>{filename_before}")
            else:
                accumulator.append(f"<add><path>{filename_after}")
                accumulator.append(f"<del><path>{filename_before}")
            continue

        line = re.sub("@@[^@@]*@@", "", line)
        if len(line) == 0:
            continue

        if line[0] == "+":
            line = line.replace("+", "<add>", 1)
        elif line[0] == "-":
            line = line.replace("-", "<del>", 1)
        else:
            line = f"<ide>{line}"

        accumulator.append(line)

    return '\n'.join(accumulator)

def predict(patch, max_length, min_length, num_beams, prediction_count):
    print(".", end="", flush=True)
    input_text = parse_files(patch)

    tokenizer = RobertaTokenizer.from_pretrained("mamiksik/CommitPredictorT5PL", revision="fb08d01", low_cpu_mem_usage=True)
    print(".", end="", flush=True)
    model = T5ForConditionalGeneration.from_pretrained("mamiksik/CommitPredictorT5PL", revision="fb08d01", low_cpu_mem_usage=True)
    print(".", end="", flush=True)

    with torch.no_grad():
        input_ids = tokenizer(
            input_text,
            truncation=True,
            padding=True,
            return_tensors="pt",
        ).input_ids
        print(".", end="", flush=True)
        outputs = model.generate(
            input_ids,
            max_length=max_length,
            min_length=min_length,
            num_beams=num_beams,
            num_return_sequences=prediction_count,
        )
        print(".", end="", flush=True)

    result = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return result[0]

if __name__ == "__main__":
    diff = subprocess.run(['git', 'diff', '--cached'], capture_output=True).stdout.decode('utf-8')

    max_message = 40
    min_message = 5
    num_beams = 10
    num_predictions = 1

    msg = predict(diff, max_message, min_message, num_beams, num_predictions)

    with open(sys.argv[1], 'r+') as f:
        content = f.read()
        f.seek(0)
        f.write(msg + '\n' + content)
        f.close()

    print("Done!\n")

На процессоре он работает быстро, но загрузка модели занимает много времени. В любом случае 3 с – это нормально.
Вот и все. Это работает. По крайней мере, у меня.

+1
0
+1
0
+1
0
+1
0
+1
0

Ответить

Ваш адрес email не будет опубликован. Обязательные поля помечены *