Как сохранить одну модель случайного леса с перекрестной проверкой?Python

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Как сохранить одну модель случайного леса с перекрестной проверкой?

Сообщение Anonymous »

Я использую 10-кратную перекрестную проверку, пытаясь предсказать двоичные метки (Y) на основе входных данных для внедрения (X).
Я хочу сохранить одну из моделей (возможно, ту, у которой самый высокий ROC AUC). . Я не знаю, как это сделать, поскольку значения ROC AUC не сохраняются и я не знаю, как их получить.
X = np.array([np.array(x) for x in df['embeddings'].values])
y = df['label'].values
groups = df['chromosome'].values
group_kfold = GroupKFold(n_splits=n_folds)

Инициализировать фигуру для построения
fig, axes = plt.subplots(1, 2, figsize=(15, 6))

all_fpr = []
all_tpr = []
all_accuracy = []
all_pr_auc = []

Perform cross-validation and plot ROC and PR curves for each fold
for i, (train_idx, val_idx) in enumerate(group_kfold.split(X, y, groups)):
X_train_fold, X_val_fold = X[train_idx], X[val_idx]
y_train_fold, y_val_fold = y[train_idx], y[val_idx]

# Initialize classifier
rf_classifier = RandomForestClassifier(n_estimators=n_trees, random_state=42, max_depth=max_depth, n_jobs=-1)

# Train the classifier on this fold
rf_classifier.fit(X_train_fold, y_train_fold)

# Make predictions on the validation set
y_pred_proba = rf_classifier.predict_proba(X_val_fold)[:, 1]

# Calculate ROC curve
fpr, tpr, _ = roc_curve(y_val_fold, y_pred_proba)

all_fpr.append(fpr)
all_tpr.append(tpr)

# Calculate AUC
roc_auc = auc(fpr, tpr)

# Plot ROC curve for this fold
axes[0].plot(fpr, tpr, lw=1, alpha=0.7, label=f'ROC Fold {i+1} (AUC = {roc_auc:.2f})')

# Calculate precision-recall curve
precision, recall, _ = precision_recall_curve(y_val_fold, y_pred_proba)

# Calculate PR AUC
pr_auc = auc(recall, precision)
all_pr_auc.append(pr_auc)

# Plot PR curve for this fold
axes[1].plot(recall, precision, lw=1, alpha=0.7, label=f'PR Curve Fold {i+1} (AUC = {pr_auc:.2f})')

# Calculate accuracy
accuracy = accuracy_score(y_val_fold, rf_classifier.predict(X_val_fold))
all_accuracy.append(accuracy)

# Initialize empty arrays to store interpolated TPR values
interpolated_tpr = []

# Define common set of thresholds
mean_fpr = np.linspace(0, 1, 100)

# Interpolate TPR values for each fold to the common set of thresholds
for fpr, tpr in zip(all_fpr, all_tpr):
interpolated_tpr.append(np.interp(mean_fpr, fpr, tpr))

# Calculate the mean and standard deviation of interpolated TPR values
mean_tpr = np.mean(interpolated_tpr, axis=0)
std_tpr = np.std(interpolated_tpr, axis=0)

# Plot the mean ROC curve with shaded area representing the standard deviation
axes[0].plot(mean_fpr, mean_tpr, color='black', linestyle='--', lw=2, label=f'Average ROC curve ({np.round(auc(mean_fpr, mean_tpr), 2)})')
axes[0].fill_between(mean_fpr, mean_tpr - std_tpr, mean_tpr + std_tpr, color='grey', alpha=0.2)

# Plot ROC for random classifier
axes[0].plot([0, 1], [0, 1], linestyle='--', lw=2, color='r', alpha=0.8)


Подробнее здесь: https://stackoverflow.com/questions/786 ... validation
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение
  • Как сохранить одну модель случайного леса с перекрестной проверкой
    Anonymous » » в форуме Python
    0 Ответы
    20 Просмотры
    Последнее сообщение Anonymous
  • Как сохранить одну модель случайного леса с перекрестной проверкой?
    Anonymous » » в форуме Python
    0 Ответы
    13 Просмотры
    Последнее сообщение Anonymous
  • Обученная модель случайного леса из Python в Matlab
    Гость » » в форуме Python
    0 Ответы
    32 Просмотры
    Последнее сообщение Гость
  • Матрица путаницы с перекрестной проверкой
    Anonymous » » в форуме Python
    0 Ответы
    30 Просмотры
    Последнее сообщение Anonymous
  • Для моего анализа данных о ценах на жилье с использованием ГБ, дерева и случайного леса моя MSE слишком высока.
    Anonymous » » в форуме Python
    0 Ответы
    31 Просмотры
    Последнее сообщение Anonymous

Вернуться в «Python»