I trained an LLM on GSM8K. Now I want to test it.
A data sample:
{'question': "Every day, Wendi feeds each of her chickens three cups of mixed chicken feed, containing seeds, mealworms and vegetables to help keep them healthy. She gives the chickens their feed in three separate meals. In the morning, she gives her flock of chickens 15 cups of feed. In the afternoon, she gives her chickens another 25 cups of feed. How many cups of feed does she need to give her chickens in the final meal of the day if the size of Wendi's flock is 20 chickens?"
'answer(label)': 'If each chicken eats 3 cups of feed per day, then for 20 chickens they would need 3*20=60 cups of feed per day.\nIf she feeds the flock 15 cups of feed in the morning, and 25 cups in the afternoon, then the final meal would require 60-15-25=20 cups of chicken feed.\n#### 20'
Whereas the generation is:
Every day, Wendi feeds each of her chickens three cups of mixed chicken feed, containing seeds, mealworms and vegetables to help keep them healthy. She gives the chickens their feed in three separate meals. In the morning, she gives her flock of chickens 15 cups of feed. In the afternoon, she gives her chickens another 25 cups of feed. How many cups of feed does she need to give her chickens in the final meal of the day if the size of Wendi's flock is 20 chickens? To determine how many cups of feed Wendi needs for the final meal of the day, we start by calculating the total amount of feed required for all the chickens.
First, we calculate the amount of feed given in a single day:
\[
15 \text{ cups (morning)} + 25 \text{ cups (afternoon)} = 40 \text{ cups}
\]
Next, we need to account for the number of chickens:
\[
20 \text{ chickens}
\]
Although the answer is correct, how can I find the label in the output?
Or how can I make the output match the label's format?
Code:
import re

def extract_num(text):
    # Regex pattern to find the number following '####'
    # (optional minus sign; thousands separators such as 1,234 are allowed)
    pattern = r'####\s*(-?\d[\d,]*)'
    # re.search returns the first match, or None if there is none
    match = re.search(pattern, text)
    if match:
        result = match.group(1)
        print(result)
    else:
        print(text)
        result = ""
    try:
        return int(result.replace(",", ""))
    except ValueError:
        print(f"'{result}' can't be converted")
        return 0
total = 0  # renamed from `all`, which shadows the built-in all()
correct = 0
samples = test_ds.select(range(10))
# samples = test_ds[:50]
# print(samples[0]['question'])
for example in samples:
    # print(f"Question: {example['question']}")
    # print(f"Answer: {example['answer']}")
    input_text = f"Question: {example['question']}\nPlease provide the answer in the format '#### + number' shortly without repeating the question."
    # Tokenize the full prompt; tokenizing only example['question'] would drop the format instruction
    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    inputs = {key: value.to(device) for key, value in inputs.items()}
    # Pass the attention mask along with the input ids
    outputs = model.generate(**inputs, max_length=512)
    pred_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    gt = extract_num(example["answer"])
    pred = extract_num(pred_text)
    correct += int(gt == pred)
    total += 1
    # if total % 100 == 0:
    print(f"{total} Acc: {correct/total:.2f}")
    # t.set_description(f"Accuracy: {correct/total*100:.2f}%")
print("Acc:", correct/total)
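One possible way to compare a free-form generation with the label is to extract a final number from both sides and compare the numbers rather than the strings. The sketch below is only an illustration under assumptions that are not in the original post: the helper names `extract_final_number` and `is_correct` are hypothetical, the model is assumed to mark its answer either with `####` or with `\boxed{...}`, and the "last number in the text" fallback is a heuristic.

```python
import re

# Integer or decimal, optional minus sign, optional thousands separators.
NUMBER = r"-?\d[\d,]*(?:\.\d+)?"


def extract_final_number(text):
    """Return the final numeric answer found in `text` as a float, or None.

    Hypothetical helper: tries the GSM8K '#### N' marker first, then a
    LaTeX '\\boxed{N}', and finally falls back to the last number in the text.
    """
    marked = re.findall(r"####\s*\$?(" + NUMBER + ")", text)
    boxed = re.findall(r"\\boxed\{\s*\$?(" + NUMBER + r")\s*\}", text)
    anywhere = re.findall(NUMBER, text)
    for candidates in (marked, boxed, anywhere):
        if candidates:
            # Use the last candidate and drop thousands separators.
            return float(candidates[-1].replace(",", ""))
    return None


def is_correct(pred_text, label_text, tol=1e-6):
    """Compare the numbers extracted from a prediction and from a label."""
    pred = extract_final_number(pred_text)
    gold = extract_final_number(label_text)
    if pred is None or gold is None:
        return False  # count unparseable outputs as wrong
    return abs(pred - gold) < tol


label = "60-15-25=20 cups of chicken feed.\n#### 20"
print(is_correct("So the final meal needs 60 - 40 = 20 cups.", label))
```

The fallback can be right for the wrong reason: a truncated generation that happens to end with "20 chickens" would also be scored as 20. For that reason it may be safer to decode only the newly generated tokens (not the echoed prompt) and to report how many answers were found through the fallback.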
More details here: https://stackoverflow.com/questions/792 ... ground-tru