Как мне визуализировать скрытое представление, создаваемое стабильной диффузией VAE?Python

Программы на Python
Ответить
Anonymous
 Как мне визуализировать скрытое представление, создаваемое стабильной диффузией VAE?

Сообщение Anonymous »

Я пытаюсь визуализировать скрытое представление, создаваемое VAE внутри конвейера стабильной диффузии

Код: Выделить всё

from diffusers import StableDiffusionPipeline
import torch

# A CUDA ordinal is simply the integer ID of a GPU in a system that has one or more GPUs.
def get_device(cuda_ordinal=None):
if torch.cuda.is_available():
return torch.device("cuda", cuda_ordinal)
if torch.backends.mps.is_available():
return torch.device("mps")
return torch.device("cpu")

device=get_device()

pipe=StableDiffusionPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
torch_dtype=torch.float16,
variant="fp16"
).to(device)
Я написал помощник load_image

Код: Выделить всё

from PIL import Image
from io import BytesIO
import requests

def load_image(url, size=None, return_tensor=False):
if url.startswith("http"):
response = requests.get(url)
img = Image.open(BytesIO(response.content))
else:
img = Image.open(url)
if size is not None:
img = img.resize(size)
if return_tensor:
return TF.to_tensor(img)
return img
Я успешно загрузил и нарисовал график с помощью библиотеки matplotlib

Код: Выделить всё

import matplotlib.pyplot as plt

im=load_image("https://media.gettyimages.com/id/2244914756/photo/a-small-island-shaped-like-a-straw-hat-by-the-blue-lake.jpg?s=2048x2048&w=gi&k=20&c=wzAXUuF06cnKTYrV7Xs8DG9jHPvrAf-1tW2Vs53VXOg=",size=(512,512))

plt.imshow(im)
plt.axis('off')  # remove axes
plt.show()
VAE сжимает это изображение в четырехканальное скрытое представление, и мне нравится визуализировать каждый канал скрытого представления VAE:

Код: Выделить всё

from torchvision import transforms
with torch.inference_mode():
tensor_im=transforms.ToTensor()(im).unsqueeze(0).to(device)*2-1
tensor_im=tensor_im.half()
latent=pipe.vae.encode(tensor_im)
latents=latent.latent_dist.sample()
latents=latents*0.18215

latents.shape
#torch.Size([1, 4, 64, 64])
это вспомогательная функция для построения каждого представления

Код: Выделить всё

import matplotlib.pyplot as plt

def show_latent_channels(latents):
lat = latents[0]  # [4, 64, 64]
num_ch = lat.shape[0]

fig, axes = plt.subplots(1, num_ch, figsize=(12,3))
for i in range(num_ch):
ch = lat[i]
# normalize channel to [0,1] for visualization
ch = (ch - ch.min()) / (ch.max() - ch.min())
axes[i].imshow(ch.cpu(), cmap="viridis")
axes[i].set_title(f"Channel {i}")
axes[i].axis("off")

plt.show()
когда я вызываю show_latent_channels(latents) я получаю это
Изображение

Моя цель — просто «увидеть», что VAE хранит внутри скрытого пространства, но matplotlib показывает пустой график

Подробнее здесь: https://stackoverflow.com/questions/798 ... ffusion-va
Ответить

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

Вернуться в «Python»