Встроенная оболочка mlx_lm.evaluate дает неправильные результаты для некоторых задач и аварийно завершает работу при генеративных задачах, поэтому я необходимо реализовать интерфейс самому.
Базовый класс lm_eval.api.model.LM требует метода loglikelihood(), который по заданному списку пар строк (контекст, продолжение) возвращает общую логарифмическую вероятность токенов продолжения, зависящих от контекста. Вот моя попытка:
Код: Выделить всё
import mlx.core as mx
import mlx.nn as nn
from mlx_lm import load
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
@register_model("mlx_native")
class MLXLM(LM):
def __init__(self, model_path: str, **kwargs):
super().__init__()
self.model, self.tokenizer = load(model_path)
def loglikelihood(self, requests):
results = []
for context, continuation in [req.args for req in requests]:
# Tokenize context and continuation separately
ctx_tokens = self.tokenizer.encode(context)
cont_tokens = self.tokenizer.encode(continuation, add_special_tokens=False)
all_tokens = ctx_tokens + cont_tokens
input_ids = mx.array(all_tokens[:-1])[None] # (1, seq_len-1)
# Forward pass to get logits
logits = self.model(input_ids) # (1, seq_len-1, vocab_size)
# Log-softmax over vocab dimension
log_probs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
# Gather log-probs for actual next tokens
target_ids = mx.array(all_tokens[1:])
# Extract log-prob at each position for the target token
token_log_probs = log_probs[0, mx.arange(len(target_ids)), target_ids]
# Sum only over continuation positions
cont_start = len(ctx_tokens) - 1
cont_log_prob = mx.sum(token_log_probs[cont_start:]).item()
# Check if greedy prediction matches continuation
greedy = mx.argmax(logits[0, cont_start:], axis=-1)
is_greedy = bool(mx.all(greedy == mx.array(cont_tokens)).item())
results.append((cont_log_prob, is_greedy))
return results
def loglikelihood_rolling(self, requests):
raise NotImplementedError # TODO
def generate_until(self, requests):
raise NotImplementedError # TODO
Код: Выделить всё
lm_eval --model hf --device mpsЯ подозреваю, что проблема заключается в том, как я вычисляю лог-вероятности: в частности, возвращает ли self.model(input_ids) логиты в той же форме/соглашении, что и model.forward() HuggingFace, или индексация на единицу между токенами контекста и продолжения неверна для поведения токенизатора MLX.
Что я проверил:
- Выходные данные токенизатора mlx_lm и HuggingFace совпадают для одних и тех же входных строк.
- Модель загружается без ошибок и генерирует связный текст с помощью mlx_lm.generate()
- Бэкэнд lm-evaluation-harness MPS обеспечивает ожидаемую точность для неквантованной модели
Среда: Python 3.11, mlx 0.22.0, mlx-lm 0.21.0, lm-eval 0.4.10, macOS 15.3, M4 Макс. 128 ГБ
Подробнее здесь: https://stackoverflow.com/questions/798 ... s-using-ml
Мобильная версия