Распределение энтропии Shap перекрывается в модели бинарной классификацииPython

Программы на Python
Ответить
Anonymous
 Распределение энтропии Shap перекрывается в модели бинарной классификации

Сообщение Anonymous »

Итак, я обучал модель классификации двоичных изображений, создавая последовательную архитектуру CNN, и после обучения модели точность теста: 0,98 и потеря теста: 0,08, поэтому теперь я хотел построить гистограмму распределения энтропии по частоте между классами, используя shap, но даже после нескольких попыток перекрытие между классами все равно происходит, в идеальном случае имеется разрыв между распределением энтропии разных классов. Как это исправить?

Код: Выделить всё

import numpy as np
import tensorflow as tf
import shap
import matplotlib.pyplot as plt
import seaborn as sns

def calculate_shap_entropy(shap_values):
"""
Calculate entropy given SHAP values.

Parameters:
shap_values (numpy.ndarray): SHAP values for samples.

Returns:
numpy.ndarray: Entropy values for each sample.
"""
abs_shap_values = np.abs(shap_values)
flattened_shap_values = abs_shap_values.reshape(len(shap_values), -1)
normalized_shap_values = flattened_shap_values / np.sum(flattened_shap_values, axis=1, keepdims=True)
entropy_values = -np.sum(normalized_shap_values * np.log(normalized_shap_values + 1e-12), axis=1)
return entropy_values

def plot_entropy_distribution(entropy_class_0, entropy_class_1, class_labels=('Not Fractured', 'Fractured')):
"""
Plot entropy distribution histograms and KDE for the two classes.

Parameters:
entropy_class_0 (numpy.ndarray): Entropy values for class 0.
entropy_class_1 (numpy.ndarray): Entropy values for class 1.
class_labels (tuple): Labels for the two classes.
"""
plt.figure(figsize=(14, 7))

# Histogram with KDE
sns.histplot(entropy_class_0, bins=30, kde=True, color='blue', label=f'{class_labels[0]}', stat="density", alpha=0.6)
sns.histplot(entropy_class_1, bins=30, kde=True, color='red', label=f'{class_labels[1]}', stat="density", alpha=0.6)

# Add vertical lines for mean and std
mean_class_0, std_class_0 = np.mean(entropy_class_0), np.std(entropy_class_0)
mean_class_1, std_class_1 = np.mean(entropy_class_1), np.std(entropy_class_1)

plt.axvline(mean_class_0, color='blue', linestyle='--', label=f'{class_labels[0]} Mean: {mean_class_0:.2f}')
plt.axvline(mean_class_1, color='red', linestyle='--', label=f'{class_labels[1]} Mean: {mean_class_1:.2f}')

plt.axvline(mean_class_0 - std_class_0, color='blue', linestyle=':', label=f'{class_labels[0]} ±1 Std')
plt.axvline(mean_class_0 + std_class_0, color='blue', linestyle=':')
plt.axvline(mean_class_1 - std_class_1, color='red', linestyle=':', label=f'{class_labels[1]} ±1 Std')
plt.axvline(mean_class_1 + std_class_1, color='red', linestyle=':')

# Plot details
plt.title('Entropy Distribution by Class', fontsize=16)
plt.xlabel('Entropy', fontsize=14)
plt.ylabel('Density', fontsize=14)
plt.legend(fontsize=12)
plt.grid(axis='y', linestyle='--', alpha=0.6)
plt.show()

# Example usage
X_train, _ = next(iter(training_set))  # Get a batch of data
X_train = X_train / 255.0
X_test, _ = next(iter(test_set))  # Get a batch of data
X_test = X_test / 255.0  # Normalize pixel values

# Create SHAP KernelExplainer with a subset of data
explainer = shap.DeepExplainer(model,X_train)  # Use first 10 samples as background

# Compute SHAP values for the same subset
shap_values = explainer.shap_values(X_test)

# Calculate entropy for each class
entropy_class_0 = calculate_shap_entropy(shap_values[0])  # SHAP values for class 0
entropy_class_1 = calculate_shap_entropy(shap_values[1])  # SHAP values for class 1

# Plot entropy distributions
plot_entropy_distribution(entropy_class_0, entropy_class_1, class_labels=('Not Fractured', 'Fractured'))
есть ли ошибка в формуле расчета энтропии
перекрывающееся изображение гистограммы нажмите здесь
изображение модели архитектуры нажмите здесь

Подробнее здесь: https://stackoverflow.com/questions/793 ... tion-model
Ответить

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

Вернуться в «Python»