Лама использует всю оперативную память, вызывая смерть ядраPython

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Лама использует всю оперативную память, вызывая смерть ядра

Сообщение Anonymous »

Код: Выделить всё

import os
import torch
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
TrainingArguments,
pipeline,
logging,
)
from peft import LoraConfig
from trl import SFTTrainer

dataset = load_dataset("csv", data_files="dataset/data.csv")

base_model = "meta-llama/Llama-3.2-1B"
compute_dtype = getattr(torch, "float16")

# Configure memory-efficient quantization
compute_dtype = getattr(torch, "float16")
quant_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_use_double_quant=True,  # Enable double quantization
)

model = AutoModelForCausalLM.from_pretrained(
base_model,
quantization_config=quant_config,
device_map="auto",  # Let transformers handle device mapping
torch_dtype=torch.float16,  # Use fp16 for model weights
low_cpu_mem_usage=True,    # Enable memory optimization
)

torch.cuda.empty_cache()
model.config.use_cache = False
model.config.pretraining_tp = 1

# Configure PEFT using LoRA for efficient fine-tuning of the model.
peft_params = LoraConfig(
lora_alpha=16,
lora_dropout=0.1,
r=8,
bias="none",
task_type="CAUSAL_LM",
target_modules="all-linear",
)

training_params = TrainingArguments(
output_dir="./results",
num_train_epochs=1,
per_device_train_batch_size=2,
gradient_accumulation_steps=2,
optim="paged_adamw_8bit",
save_steps=50,
logging_steps=50,
learning_rate=2e-4,
weight_decay=0.001,
fp16=True,
bf16=False,
max_grad_norm=0.3,
max_steps=-1,
warmup_ratio=0.03,
group_by_length=True,
lr_scheduler_type="constant",
report_to="tensorboard",
gradient_checkpointing=True,
)

tokenizer = AutoTokenizer.from_pretrained(
base_model,
padding_side="right",
truncation_side="right",
)
tokenizer.pad_token = tokenizer.eos_token

trainer = SFTTrainer(
model=model,
train_dataset=dataset['train'],
peft_config=peft_params,
dataset_text_field="input_text",
max_seq_length=512,
tokenizer=tokenizer,
args=training_params,
packing=False,
)

trainer.train()
Я использую приведенный выше блок кода для точной настройки параметров Llama 1-B с помощью моего компьютера с 128 ГБ ОЗУ и графическим процессором 4090. Соответственно, ПК соответствует всем требованиям модели, но при выполнении строки Trainer = SFTTrainer(....) оперативная память заполняется и ядро ​​умирает, что неожиданно. Размер набора данных составляет всего 10 ГБ с 7400 строками данных. Буду рад, если кто-нибудь поможет мне решить эту проблему. Отслеживание ошибок выглядит следующим образом (используется вся память 128 ГБ, терминал выключается) -

Код: Выделить всё

Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.
warnings.warn(message, FutureWarning)
/home/.../python3.10/site-packages/trl/trainer/sft_trainer.py:300: UserWarning: You passed a `max_seq_length` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.
warnings.warn(
/home/.../python3.10/site-packages/trl/trainer/sft_trainer.py:328: UserWarning: You passed a `dataset_text_field` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.
warnings.warn(
Map:  14%|██████████████████▍                                                                                                                    | 1000/7346 [02:59

Подробнее здесь: [url]https://stackoverflow.com/questions/79200200/llama-using-up-all-ram-storage-causing-kernel-to-die[/url]
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение

Вернуться в «Python»