Как реализовать loglikelihood() для lm-evaluation-harness на основе MLX с использованием mlx_lm?Python

Программы на Python
Ответить
Anonymous
 Как реализовать loglikelihood() для lm-evaluation-harness на основе MLX с использованием mlx_lm?

Сообщение Anonymous »

Я пишу пользовательскую серверную часть модели lm-evaluation-harness, которая выполняет вывод через платформу Apple MLX через mlx_lm, поэтому я могу тестировать квантованные LLM непосредственно на Apple Silicon без маршрутизации через PyTorch MPS.
Встроенная оболочка mlx_lm.evaluate дает неправильные результаты для некоторых задач и аварийно завершает работу при генеративных задачах, поэтому я необходимо реализовать интерфейс самому.
Базовый класс lm_eval.api.model.LM требует метода loglikelihood(), который по заданному списку пар строк (контекст, продолжение) возвращает общую логарифмическую вероятность токенов продолжения, зависящих от контекста. Вот моя попытка:

Код: Выделить всё

import mlx.core as mx
import mlx.nn as nn
from mlx_lm import load
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model

@register_model("mlx_native")
class MLXLM(LM):
def __init__(self, model_path: str, **kwargs):
super().__init__()
self.model, self.tokenizer = load(model_path)

def loglikelihood(self, requests):
results = []
for context, continuation in [req.args for req in requests]:
# Tokenize context and continuation separately
ctx_tokens = self.tokenizer.encode(context)
cont_tokens = self.tokenizer.encode(continuation, add_special_tokens=False)
all_tokens = ctx_tokens + cont_tokens

input_ids = mx.array(all_tokens[:-1])[None]  # (1, seq_len-1)

# Forward pass to get logits
logits = self.model(input_ids)  # (1, seq_len-1, vocab_size)

# Log-softmax over vocab dimension
log_probs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)

# Gather log-probs for actual next tokens
target_ids = mx.array(all_tokens[1:])
# Extract log-prob at each position for the target token
token_log_probs = log_probs[0, mx.arange(len(target_ids)), target_ids]

# Sum only over continuation positions
cont_start = len(ctx_tokens) - 1
cont_log_prob = mx.sum(token_log_probs[cont_start:]).item()

# Check if greedy prediction matches continuation
greedy = mx.argmax(logits[0, cont_start:], axis=-1)
is_greedy = bool(mx.all(greedy == mx.array(cont_tokens)).item())

results.append((cont_log_prob, is_greedy))
return results

def loglikelihood_rolling(self, requests):
raise NotImplementedError  # TODO

def generate_until(self, requests):
raise NotImplementedError  # TODO
Проблема: Когда я запускаю это на arc_easy с mlx-community/Llama-3.1-8B-Instruct-4bit, я получаю результаты по точности примерно на 3–5 процентных пунктов ниже, чем в справочнике HuggingFace Transformers (

Код: Выделить всё

lm_eval --model hf --device mps
) с теми же весами базовой модели.
Я подозреваю, что проблема заключается в том, как я вычисляю лог-вероятности: в частности, возвращает ли self.model(input_ids) логиты в той же форме/соглашении, что и model.forward() HuggingFace, или индексация на единицу между токенами контекста и продолжения неверна для поведения токенизатора MLX.
Что я проверил:
  • Выходные данные токенизатора mlx_lm и HuggingFace совпадают для одних и тех же входных строк.
  • Модель загружается без ошибок и генерирует связный текст с помощью mlx_lm.generate()
  • Бэкэнд lm-evaluation-harness MPS обеспечивает ожидаемую точность для неквантованной модели
Что мне не удалось определить: возвращает ли model(input_ids) в mlx_lm логиты, выровненные так же, как HuggingFace (т. е. logits) token[i+1]), поскольку источник mlx_lm использует функциюgenerate_step(), которая обрабатывает это внутри, а необработанное поведение __call__ не документировано для сценариев использования оценки.
Среда: Python 3.11, mlx 0.22.0, mlx-lm 0.21.0, lm-eval 0.4.10, macOS 15.3, M4 Макс. 128 ГБ

Подробнее здесь: https://stackoverflow.com/questions/798 ... s-using-ml
Ответить

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

Вернуться в «Python»