Различные размеры подзаголовков в matplotlib/seabornPython

Программы на Python
Ответить
Anonymous
 Различные размеры подзаголовков в matplotlib/seaborn

Сообщение Anonymous »

Я хотел бы создать фигуру с графиками 3x3 (тепловыми картами), ширина одинакова для всех 9 графиков. Однако высота должна быть разной для каждого графика, так как высота должна быть равна number_of_cells_of Heatplot*cell_height. Теперь я могу сделать так, чтобы каждая строка на моем рисунке имела разную высоту, но не могу изменить высоту подграфиков внутри строки. Я приложил скриншот, чтобы было понятнее, мне бы хотелось, чтобы ячейки всех тепловых графиков имели одинаковую высоту, тогда это будет означать, например, подграфик в первой строке и втором столбце будет всего одной строкой.
Это мой код на данный момент

Код: Выделить всё

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import gridspec

groups = ['g1', 'g2', 'g3', 'g4', 'g5', 'g6', 'g7', 'g8', 'g9']

dummy_data = {
'names':['name_1', 'name_2', 'name_3', 'name_4', 'name_5', 'name_6', 'name_7', 'name_8', 'name_9', 'name_10', 'name_11', 'name_12', 'name_13', 'name_14', 'name_15', 'name_16', 'name_17', 'name_18', 'name_19', 'name_20']
,
'group': [
'g1', 'g1', 'g1', 'g1', 'g2', 'g3', 'g3', 'g3', 'g4', 'g4',
'g5', 'g5', 'g6', 'g6', 'g7', 'g7', 'g8', 'g9', 'g9', 'g9'
],
'col1': [
-30, 10, 5, -20, 15, 25, -15, 0, -10, 20,
-5, 15, 35, -25, 10, 5, -15, 30, 10, -5
],
'col2': [
-50, 20, 30, -40, 25, 45, -10, 5, -15, 35,
-10, 20, 40, -20, 15, 10, -20, 45, 20, -10
],
'col3': [
-45, 15, 35, -35, 20, 40, -5, 10, -10, 30,
-15, 25, 45, -15, 12, 7, -18, 40, 25, -8
],
'col4': [
0.05, 0.08, 0.07, 0.06, 0.09, 0.11, 0.05, 0.08, 0.07, 0.10,
0.04, 0.09, 0.12, 0.06, 0.07, 0.08, 0.06, 0.11, 0.05, 0.09
],
'col5': [
0.12, 0.07, 0.09, 0.04, 0.08, 0.10, 0.06, 0.09, 0.05, 0.11,
0.03, 0.10, 0.13, 0.07, 0.06, 0.07, 0.05, 0.10, 0.04, 0.08
],
'col6': [
20, 25, 15, 30, 40, 10, 35, 5, 45, 25,
10, 20, 50, 15, 30, 25, 40, 15, 10, 35
],
'col7': [
45, 15, 25, 10, 30, 40, 20, 35, 5, 20,
30, 15, 40, 20, 10, 15, 50, 30, 5, 25
],
'col8': [
5, 2, 8, 3, 7, 6, 4, 9, 1, 8,
6, 3, 10, 2, 7, 4, 9, 5, 2, 8
],
'col9': [
6, 3, 9, 4, 8, 7, 5, 10, 2, 9,
7, 4, 10, 3, 8, 5, 9, 6, 3, 9
]
}

pivot_df = pd.DataFrame(dummy_data)

# Define column bounds with new names
column_bounds = {
'col1': (-35, 35),
'col2': (-55, 55),
'col3': (-50, 50),
'col4': (0.03, 0.15),
'col5': (0.03, 0.15),
'col6': (0, 50),
'col7': (0, 50),
'col8': (0, 10),
'col9': (0, 10),
}

# Calculate row counts for each group
row_counts = [len(pivot_df.loc[pivot_df['group'] == group]) for group in groups]

# Calculate height ratios based on the row count for each group to ensure each cell has the same height
cell_height = 0.5  # Set a fixed cell height
height_ratios = [count * cell_height for count in row_counts]

# Create the figure and GridSpec with proportional height ratios for each subplot
fig = plt.figure(figsize=(32, 21))
gs = gridspec.GridSpec(3, 3, height_ratios=[max(height_ratios[:3]), max(height_ratios[3:6]), max(height_ratios[6:])])

# Add subplots to the grid
axes = [fig.add_subplot(gs[i // 3, i % 3]) for i in range(9)]

for i, group in enumerate(groups):
# Get the data for the current group
group_data = pivot_df.loc[pivot_df['group'] == group].set_index('names').drop(columns={'group'})

# Create annotation array
annotations = group_data.copy()
for col in ['col6', 'col7', 'col8', 'col9']:
annotations[col] = annotations[col].apply(lambda x: f"{x:.2f}M")
for col in ['col1', 'col2', 'col3', 'col4', 'col5']:
annotations[col] = annotations[col].apply(lambda x: f"{x:.2f}")

# Plot each column individually with separate bounds
for j, col in enumerate(group_data.columns):
# Create a mask for other columns to plot each individually
mask = np.ones(group_data.shape)
mask[:, j] = 0  # Mask all except the current column

sns.heatmap(
group_data,
annot=annotations,
fmt='',
cmap="coolwarm",
cbar=False,
vmin=column_bounds[col][0],
vmax=column_bounds[col][1],
mask=mask,
ax=axes[i]
)

axes[i].set_title(group)
axes[i].set_ylabel('')

# Rotate y-tick labels to be horizontal if there are any labels
yticklabels = group_data.index.tolist()
axes[i].set_yticks(np.arange(len(group_data)))  # Set tick positions based on number of rows
axes[i].set_yticklabels(yticklabels, rotation=0, ha='right', fontsize=8)

if i <  6:  # Only for the top two rows (0–5), remove x labels and ticks
axes[i].set_xlabel('')
axes[i].set_xticks([])

fig.tight_layout()
plt.subplots_adjust(wspace=0.4, hspace=0.2)

plt.show()
Изображение
-

Подробнее здесь: https://stackoverflow.com/questions/791 ... ib-seaborn
Ответить

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

Вернуться в «Python»