Anonymous
Расхождение в параметрах модели (70,13 млн в статье против >100 млн в реализации)
Сообщение
Anonymous » 06 ноя 2025, 13:19
В настоящее время я воспроизвожу архитектуру
D-TrAttUnet из статьи:
D-TrAttUnet: На пути к гибридной архитектуре CNN-трансформатора для общей и тонкой сегментации медицинских изображений
https://doi.org/10.1016/j.compbiomed.2024.108590
Я следовал официальной реализации, представленной в их репозитории GitHub:
https://github.com/faresbougourzi/D-TrAttUnet
Согласно статье, общее количество параметров модели составляет
70,13 миллиона:
Однако, когда я реализую архитектуру (используя тот же код из Architecture.py), я стабильно получаю
более 100 миллионов параметров (около 104–132 млн, в зависимости от среды).
Я подозреваю, что несоответствие может быть связано с различиями версий используемых библиотек
MONAI или
PyTorch, поскольку статья была опубликована в 2024 году.
Я пробовал несколько версий MONAI, выпущенных как в 2023, так и в 2024 году (например, monai==1.1.0, 1.2.0, 1.3.0), а также версии PyTorch 2.0.0, 2.1.0 и 2.2.0, но количество параметров остаётся значительно выше заявленного в статье.
Код: Выделить всё
!pip install --quiet monai
import torch
import torch.nn as nn
from itertools import repeat
from typing import Union, Sequence
from monai.networks.blocks.transformerblock import TransformerBlock
class PatchEmbeddingBlock(nn.Module):
    """Split a 2D image into non-overlapping patches and embed each patch.

    The projection is a single ``nn.Conv2d`` whose kernel size and stride are
    both equal to the patch size, so each output location corresponds to
    exactly one patch.

    Args:
        img_size: Nominal input size. An int is used for both axes; a
            2-element tuple/list is taken as ``(height, width)``. It is only
            used to derive ``grid_size`` and ``num_patches``; the projection
            itself does not depend on it.
        patch_size: Patch size, an int or a 2-element tuple/list.
        in_chans: Number of channels of the input image.
        embed_dim: Number of output channels of the projection, i.e. the
            length of each patch embedding.
        norm_layer: Accepted for signature compatibility; not used.
        flatten: When True (the default) ``forward`` returns a token sequence
            of shape ``(batch, num_patches, embed_dim)``. When False the
            un-flattened ``(batch, embed_dim, grid_h, grid_w)`` map is
            returned.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
        super().__init__()
        img_size = self._pair(img_size)
        patch_size = self._pair(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    @staticmethod
    def _pair(value):
        """Return ``value`` as a 2-tuple, duplicating it when it is a scalar."""
        if isinstance(value, (tuple, list)):
            if len(value) != 2:
                raise ValueError(f"expected a scalar or a 2-element sequence, got {value!r}")
            return tuple(value)
        return (value, value)

    def forward(self, x):
        """Embed ``x`` of shape ``(batch, in_chans, height, width)``.

        Raises:
            ValueError: If ``x`` is not a 4D tensor.
        """
        if x.dim() != 4:
            raise ValueError(
                f"expected a 4D (batch, channels, height, width) tensor, got {x.dim()}D"
            )
        x = self.proj(x)
        if self.flatten:
            # (B, E, gh, gw) -> (B, E, gh*gw) -> (B, gh*gw, E)
            x = x.flatten(2).transpose(1, 2)
        return x
##############################################################################
# Transformer Path
class ViT(nn.Module):
    """Transformer path: patch embedding followed by a stack of transformer blocks.

    This module adds no positional embedding and no class token; the patch
    tokens are fed directly into the transformer blocks.

    Args:
        in_channels: Number of channels of the input image.
        img_size: Nominal input size, forwarded unchanged to
            ``PatchEmbeddingBlock``.
        patch_size: Patch size, forwarded unchanged to
            ``PatchEmbeddingBlock``.
        hidden_size: Token embedding width used by the patch projection, the
            transformer blocks and the final ``LayerNorm``.
        mlp_dim: Hidden width passed to each ``TransformerBlock``.
        num_layers: Number of transformer blocks.
        num_heads: Number of attention heads passed to each block.
        dropout_rate: Dropout rate passed to each block.
    """

    def __init__(
        self,
        in_channels: int,
        img_size: Union[Sequence[int], int],
        patch_size: Union[Sequence[int], int],
        hidden_size: int = 768,
        mlp_dim: int = 3072,
        num_layers: int = 12,
        num_heads: int = 12,
        dropout_rate: float = 0.0,
    ):
        super().__init__()
        # Forward the constructor arguments so the embedding width always
        # matches the width expected by the transformer blocks.
        self.patch_embedding = PatchEmbeddingBlock(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_channels,
            embed_dim=hidden_size,
        )
        self.blocks = nn.ModuleList(
            [TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate) for _ in range(num_layers)]
        )
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, x):
        """Return ``(normed_tokens, hidden_states)``.

        ``hidden_states`` holds the output of every transformer block in
        order, taken before the final ``LayerNorm``.
        """
        x = self.patch_embedding(x)
        hidden_states_out = []
        for blk in self.blocks:
            x = blk(x)
            hidden_states_out.append(x)
        x = self.norm(x)
        return x, hidden_states_out
##############################################################################
# Attention Block
class Attention_block(nn.Module):
    """Additive attention gate that rescales a skip feature map.

    Both inputs are projected to ``F_int`` channels with 1x1 convolutions,
    summed, passed through ReLU, and reduced to a single-channel sigmoid mask
    that multiplies ``x``.

    Args:
        F_g: Number of channels of the gating signal ``g``.
        F_l: Number of channels of the skip feature map ``x``.
        F_int: Number of channels of the intermediate projection.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return ``x`` multiplied by the attention mask computed from ``g`` and ``x``.

        ``g`` and ``x`` must share batch and spatial dimensions because their
        projections are added element-wise.
        """
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        # The in-place ReLU acts on the freshly created sum, not on g or x.
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        # psi has one channel and broadcasts over the channels of x.
        return x * psi
##############################################################################
# ResBlock
class DoubleConv(nn.Module):
    """Two 3x3 conv-BN-ReLU layers summed with a 1x1 conv-BN-ReLU shortcut.

    All convolutions use stride 1 and are bias-free; the 3x3 convolutions use
    padding 1, so the spatial size of the input is preserved.

    Args:
        in_channels: Number of channels of the input.
        out_channels: Number of channels of the output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        # The shortcut also ends with a ReLU, so both summands are non-negative.
        self.skip = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Return the sum of the main branch and the shortcut branch."""
        return self.conv(x) + self.skip(x)
##############################################################################
# UpResBlock
class UPDoubleConv(nn.Module):
    """Chain of bilinear x2 upsampling stages, each followed by a ``DoubleConv``.

    The first stage maps ``in_channels`` to ``out_channels``; it is followed
    by ``num_layer`` further stages that keep ``out_channels``. The spatial
    size therefore grows by a factor of ``2 ** (num_layer + 1)``.

    Args:
        in_channels: Number of channels of the input.
        out_channels: Number of channels produced by every stage.
        num_layer: Number of additional upsampling stages after the first.
    """

    def __init__(self, in_channels, out_channels, num_layer):
        super().__init__()
        self.deconv = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            DoubleConv(in_channels, out_channels),
        )
        self.blocks = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                    DoubleConv(out_channels, out_channels),
                )
                for _ in range(num_layer)
            ]
        )

    def forward(self, x):
        """Apply the first stage and then every additional stage in order."""
        x = self.deconv(x)
        for blk in self.blocks:
            x = blk(x)
        return x
##############################################################################
# D-TrAttUnet Architecture
class DTrAttUnet(nn.Module):
    """Hybrid CNN/transformer encoder with two attention-gated decoders.

    A ViT runs on the input image. Its intermediate token maps are reshaped
    to 2D, upsampled, and concatenated with the pooled CNN features at each
    encoder level. Two structurally identical decoders then consume the same
    encoder features and produce two separate outputs.

    Args:
        in_channels: Number of channels of the input image. Used for the
            first CNN block and for the ViT patch embedding.
        out_channels: Number of channels of the first output.
        img_size: Nominal input size, forwarded to the ViT.
        feature_size: Unused; kept so existing call sites keep working.
        hidden_size: Token width of the ViT and input width of the blocks
            that consume ViT features.
        mlp_dim: MLP width of the ViT blocks.
        num_heads: Number of attention heads of the ViT blocks.

    Note:
        The patch size is fixed at 16 because the upsampling depth of
        ``encoder2``..``encoder5`` is tied to it. Input height and width
        should be multiples of 16 so the pooled CNN maps and the upsampled
        ViT maps have matching sizes when concatenated.
    """

    # Fixed by the encoder topology; see the class docstring.
    patch_size = 16

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        img_size: Union[Sequence[int], int],
        feature_size: int = 16,
        hidden_size: int = 768,
        mlp_dim: int = 3072,
        num_heads: int = 12,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.classification = False
        self.pool = nn.MaxPool2d(2, 2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # Build the ViT from the constructor arguments instead of fixed
        # literals so it stays consistent with the CNN path.
        self.vit = ViT(
            in_channels=in_channels,
            img_size=img_size,
            patch_size=self.patch_size,
            hidden_size=hidden_size,
            mlp_dim=mlp_dim,
            num_heads=num_heads,
        )
        nb_filter = [32, 64, 128, 256, 512]

        # NOTE: encoder1 is never called in forward(). It is kept because
        # removing it would change the state_dict keys and parameter count.
        self.encoder1 = DoubleConv(in_channels, nb_filter[0])

        # CNN encoder path.
        self.conv1 = DoubleConv(in_channels, nb_filter[0])
        self.conv2 = DoubleConv(nb_filter[1], nb_filter[1])
        self.conv3 = DoubleConv(nb_filter[2], nb_filter[2])
        self.conv4 = DoubleConv(nb_filter[3], nb_filter[3])
        self.conv5 = DoubleConv(nb_filter[4], nb_filter[4])

        # Blocks that bring ViT token maps to the resolution and width of
        # the CNN level they are concatenated with.
        self.encoder2 = UPDoubleConv(
            in_channels=hidden_size, out_channels=nb_filter[0], num_layer=2
        )
        self.encoder3 = UPDoubleConv(
            in_channels=hidden_size, out_channels=nb_filter[1], num_layer=1
        )
        self.encoder4 = UPDoubleConv(
            in_channels=hidden_size, out_channels=nb_filter[2], num_layer=0
        )
        self.encoder5 = DoubleConv(hidden_size, nb_filter[3])

        # Decoder1
        self.Att4 = Attention_block(F_g=nb_filter[4], F_l=nb_filter[3], F_int=nb_filter[3])
        self.Att3 = Attention_block(F_g=nb_filter[3], F_l=nb_filter[2], F_int=nb_filter[2])
        self.Att2 = Attention_block(F_g=nb_filter[2], F_l=nb_filter[1], F_int=nb_filter[1])
        self.Att1 = Attention_block(F_g=nb_filter[1], F_l=nb_filter[0], F_int=nb_filter[0])
        self.deconv1 = DoubleConv(nb_filter[3] + nb_filter[4], nb_filter[3])
        self.deconv2 = DoubleConv(nb_filter[2] + nb_filter[3], nb_filter[2])
        self.deconv3 = DoubleConv(nb_filter[1] + nb_filter[2], nb_filter[1])
        self.deconv4 = DoubleConv(nb_filter[0] + nb_filter[1], nb_filter[0])
        self.final = nn.Conv2d(nb_filter[0], out_channels, kernel_size=1)

        # Decoder2 (its head always has a single output channel)
        self.Att41 = Attention_block(F_g=nb_filter[4], F_l=nb_filter[3], F_int=nb_filter[3])
        self.Att31 = Attention_block(F_g=nb_filter[3], F_l=nb_filter[2], F_int=nb_filter[2])
        self.Att21 = Attention_block(F_g=nb_filter[2], F_l=nb_filter[1], F_int=nb_filter[1])
        self.Att11 = Attention_block(F_g=nb_filter[1], F_l=nb_filter[0], F_int=nb_filter[0])
        self.deconv11 = DoubleConv(nb_filter[3] + nb_filter[4], nb_filter[3])
        self.deconv21 = DoubleConv(nb_filter[2] + nb_filter[3], nb_filter[2])
        self.deconv31 = DoubleConv(nb_filter[1] + nb_filter[2], nb_filter[1])
        self.deconv41 = DoubleConv(nb_filter[0] + nb_filter[1], nb_filter[0])
        self.final1 = nn.Conv2d(nb_filter[0], 1, kernel_size=1)

    def proj_feat(self, x, hidden_size, feat_size):
        """Reshape tokens ``(batch, n, hidden_size)`` into a channel-first 2D map.

        Args:
            x: Token tensor whose second dimension equals the number of grid
                cells described by ``feat_size``.
            hidden_size: Length of each token.
            feat_size: Grid side as an int (square grid) or a
                ``(height, width)`` pair.

        Returns:
            A contiguous tensor of shape ``(batch, hidden_size, height, width)``.
        """
        if isinstance(feat_size, int):
            feat_h = feat_w = feat_size
        else:
            feat_h, feat_w = feat_size
        x = x.view(x.size(0), feat_h, feat_w, hidden_size)
        x = x.permute(0, 3, 1, 2).contiguous()
        return x

    def _gate_and_fuse(self, att, deconv, gate, skip):
        """Gate ``skip`` with ``gate``, concatenate both, and apply ``deconv``."""
        return deconv(torch.cat([att(g=gate, x=skip), gate], 1))

    def forward(self, x_in):
        """Return ``(output, output2)``, one tensor per decoder."""
        v4, hidden_states_out = self.vit(x_in)
        hidden_size = self.hidden_size
        # Derive the token grid from the actual input instead of assuming a
        # 224x224 image (which gives 14x14).
        feat_size = (
            x_in.shape[-2] // self.patch_size,
            x_in.shape[-1] // self.patch_size,
        )

        # Encoder: each level fuses pooled CNN features with a ViT feature
        # map taken from transformer blocks at index 3, 6, 9 and the final
        # normalized output.
        x1 = self.conv1(x_in)
        enc2 = self.encoder2(self.proj_feat(hidden_states_out[3], hidden_size, feat_size))
        x2 = self.conv2(torch.cat([self.pool(x1), enc2], dim=1))
        enc3 = self.encoder3(self.proj_feat(hidden_states_out[6], hidden_size, feat_size))
        x3 = self.conv3(torch.cat([self.pool(x2), enc3], dim=1))
        enc4 = self.encoder4(self.proj_feat(hidden_states_out[9], hidden_size, feat_size))
        x4 = self.conv4(torch.cat([self.pool(x3), enc4], dim=1))
        enc5 = self.encoder5(self.proj_feat(v4, hidden_size, feat_size))
        x5 = self.conv5(torch.cat([self.pool(x4), enc5], dim=1))

        # Both decoders start from the same upsampled bottleneck; nothing
        # modifies it in place, so it is computed once.
        x5_up = self.up(x5)

        # Decoder 1
        d1 = self._gate_and_fuse(self.Att4, self.deconv1, x5_up, x4)
        d2 = self._gate_and_fuse(self.Att3, self.deconv2, self.up(d1), x3)
        d3 = self._gate_and_fuse(self.Att2, self.deconv3, self.up(d2), x2)
        d4 = self._gate_and_fuse(self.Att1, self.deconv4, self.up(d3), x1)

        # Decoder 2
        d11 = self._gate_and_fuse(self.Att41, self.deconv11, x5_up, x4)
        d21 = self._gate_and_fuse(self.Att31, self.deconv21, self.up(d11), x3)
        d31 = self._gate_and_fuse(self.Att21, self.deconv31, self.up(d21), x2)
        d41 = self._gate_and_fuse(self.Att11, self.deconv41, self.up(d31), x1)

        output = self.final(d4)
        output2 = self.final1(d41)
        return output, output2
model = DTrAttUnet(in_channels=3, out_channels=1, img_size=224)

# Tally parameters in one pass, split by whether they receive gradients.
# model.parameters() yields parameters only; buffers such as BatchNorm
# running statistics are not included in these numbers.
trainable_params = 0
non_trainable_params = 0
for param in model.parameters():
    if param.requires_grad:
        trainable_params += param.numel()
    else:
        non_trainable_params += param.numel()
total_params = trainable_params + non_trainable_params

print(f"Trainable params: {trainable_params:,}")
print(f"Non-trainable params: {non_trainable_params:,}")
print(f"Total params: {total_params:,}")
Код архитектуры взят непосредственно из файла Architecture.py.
Вы можете попробовать его напрямую с помощью colab.
Мой вопрос:
Может ли быть разница в том, как определенные компоненты MONAI или ViT инициализируются или подсчитываются в более новых версиях?
Я упускаю детали конфигурации или изменения в архитектуре, которые могли бы объяснить эту разницу?
Действительно ли общее число параметров в их публичной реализации превышает 100 млн?
Подробнее здесь:
https://stackoverflow.com/questions/79811131/discrepancy-in-model-parameters-70-13m-in-paper-vs-100m-in-implementation
1762424353
Anonymous
В настоящее время я воспроизвожу архитектуру [b]D-TrAttUnet[/b] из статьи: D-TrAttUnet: На пути к гибридной архитектуре CNN-трансформатора для общей и тонкой сегментации медицинских изображений https://doi.org/10.1016/j.compbiomed.2024.108590 Я следовал официальной реализации, представленной в их репозитории GitHub: https://github.com/faresbougourzi/D-TrAttUnet Согласно статье, общее количество параметров модели составляет [b]70,13 миллиона[/b]: [img]https://i.sstatic.net/LhSFbYLd.png[/img] Однако, когда я реализую архитектуру (используя тот же код из Architecture.py), я последовательно получаю [b]более 100 миллионов параметров[/b] (около 104–132M, в зависимости от среды). Я подозреваю, что несоответствие может быть связано с различиями версий используемых библиотек [b]MONAI[/b] или [b]PyTorch[/b], поскольку статья была опубликована в 2024 году. Я пробовал несколько версий MONAI как из [b]2023[/b], так и из [b]2024[/b]. (например, monai==1.1.0, 1.2.0, 1.3.0) и версии PyTorch, такие как 2.0.0, 2.1.0 и 2.2.0, но количество параметров остается значительно выше, чем сообщается. 
[code]!pip install --quiet monai import torch import torch.nn as nn from itertools import repeat from typing import Union, Sequence from monai.networks.blocks.transformerblock import TransformerBlock class PatchEmbeddingBlock(nn.Module): """ 2D Image to Patch Embedding """ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): super().__init__() img_size = tuple(repeat(img_size, 2)) patch_size = tuple(repeat(patch_size, 2)) self.img_size = img_size self.patch_size = patch_size self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) self.num_patches = self.grid_size[0] * self.grid_size[1] self.flatten = flatten self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) def forward(self, x): B, C, H, W = x.shape x = self.proj(x).flatten(2).transpose(1, 2) return x ############################################################################## # Transformer Path class ViT(nn.Module): def __init__( self, in_channels: int, img_size: Union[Sequence[int], int], patch_size: Union[Sequence[int], int], hidden_size: int = 768, mlp_dim: int = 3072, num_layers: int = 12, num_heads: int = 12, dropout_rate: float = 0.0, ): super().__init__() self.patch_embedding = PatchEmbeddingBlock() self.blocks = nn.ModuleList( [TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate) for i in range(num_layers)] ) self.norm = nn.LayerNorm(hidden_size) def forward(self, x): x = self.patch_embedding(x) hidden_states_out = [] for blk in self.blocks: x = blk(x) hidden_states_out.append(x) x = self.norm(x) return x, hidden_states_out ############################################################################## # Attention Block class Attention_block(nn.Module): def __init__(self,F_g,F_l,F_int): super(Attention_block,self).__init__() self.W_g = nn.Sequential( nn.Conv2d(F_g, F_int, kernel_size=1,stride=1,padding=0,bias=True), nn.BatchNorm2d(F_int) ) self.W_x = nn.Sequential( nn.Conv2d(F_l, 
F_int, kernel_size=1,stride=1,padding=0,bias=True), nn.BatchNorm2d(F_int) ) self.psi = nn.Sequential( nn.Conv2d(F_int, 1, kernel_size=1,stride=1,padding=0,bias=True), nn.BatchNorm2d(1), nn.Sigmoid() ) self.relu = nn.ReLU(inplace=True) def forward(self, g, x): g1 = self.W_g(g) x1 = self.W_x(x) psi = self.relu(g1+x1) psi = self.psi(psi) return x*psi ############################################################################## # ResBlock class DoubleConv(nn.Module): def __init__(self, in_channels, out_channels): super(DoubleConv, self).__init__() self.conv = nn.Sequential( nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), ) self.skip = nn.Sequential( nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)) def forward(self, x): return self.conv(x) + self.skip(x) ############################################################################## # UpResBlock class UPDoubleConv(nn.Module): def __init__(self, in_channels, out_channels, num_layer): super(UPDoubleConv, self).__init__() self.deconv = nn.Sequential( nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), DoubleConv(in_channels, out_channels) ) self.blocks = nn.ModuleList( [ nn.Sequential( nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), DoubleConv(out_channels, out_channels) ) for i in range(num_layer) ] ) def forward(self, x): x = self.deconv(x) for blk in self.blocks: x = blk(x) return x ############################################################################## # D-TrAttUnet Architecture class DTrAttUnet(nn.Module): def __init__( self, in_channels: int, out_channels: int, img_size: Union[Sequence[int], int], feature_size: int = 16, hidden_size: int = 768, mlp_dim: int = 3072, num_heads: int = 12): super().__init__() self.hidden_size = hidden_size 
self.classification = False self.pool = nn.MaxPool2d(2, 2) self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) self.vit = ViT(in_channels=3,img_size=224,patch_size=16) nb_filter = [32, 64, 128, 256, 512] self.encoder1 = DoubleConv(in_channels, nb_filter[0]) self.conv1 = DoubleConv(in_channels, nb_filter[0]) self.conv2 = DoubleConv(nb_filter[1], nb_filter[1]) self.conv3 = DoubleConv(nb_filter[2], nb_filter[2]) self.conv4 = DoubleConv(nb_filter[3], nb_filter[3]) self.conv5 = DoubleConv(nb_filter[4], nb_filter[4]) self.encoder2 = UPDoubleConv( in_channels = hidden_size, out_channels = nb_filter[0], num_layer = 2 ) self.encoder3 = UPDoubleConv( in_channels = hidden_size, out_channels = nb_filter[1], num_layer = 1 ) self.encoder4 = UPDoubleConv( in_channels = hidden_size, out_channels = nb_filter[2], num_layer = 0 ) self.encoder5 = DoubleConv(hidden_size, nb_filter[3]) # Decoder1 self.Att4 = Attention_block(F_g= nb_filter[4], F_l=nb_filter[3], F_int=nb_filter[3]) self.Att3 = Attention_block(F_g= nb_filter[3], F_l=nb_filter[2], F_int=nb_filter[2]) self.Att2 = Attention_block(F_g= nb_filter[2], F_l=nb_filter[1], F_int=nb_filter[1]) self.Att1 = Attention_block(F_g= nb_filter[1], F_l=nb_filter[0], F_int=nb_filter[0]) self.deconv1 = DoubleConv(nb_filter[3]+nb_filter[4], nb_filter[3]) self.deconv2 = DoubleConv(nb_filter[2]+nb_filter[3], nb_filter[2]) self.deconv3 = DoubleConv(nb_filter[1]+nb_filter[2], nb_filter[1]) self.deconv4 = DoubleConv(nb_filter[0]+nb_filter[1], nb_filter[0]) self.final = nn.Conv2d(nb_filter[0], out_channels, kernel_size=1) # Decoder2 self.Att41 = Attention_block(F_g= nb_filter[4], F_l=nb_filter[3], F_int=nb_filter[3]) self.Att31 = Attention_block(F_g= nb_filter[3], F_l=nb_filter[2], F_int=nb_filter[2]) self.Att21 = Attention_block(F_g= nb_filter[2], F_l=nb_filter[1], F_int=nb_filter[1]) self.Att11 = Attention_block(F_g= nb_filter[1], F_l=nb_filter[0], F_int=nb_filter[0]) self.deconv11 = DoubleConv(nb_filter[3]+nb_filter[4], 
nb_filter[3]) self.deconv21 = DoubleConv(nb_filter[2]+nb_filter[3], nb_filter[2]) self.deconv31 = DoubleConv(nb_filter[1]+nb_filter[2], nb_filter[1]) self.deconv41 = DoubleConv(nb_filter[0]+nb_filter[1], nb_filter[0]) self.final1 = nn.Conv2d(nb_filter[0], 1, kernel_size=1) def proj_feat(self, x, hidden_size, feat_size): x = x.view(x.size(0), feat_size, feat_size, hidden_size) x = x.permute(0, 3, 1, 2).contiguous() return x def forward(self, x_in): v4, hidden_states_out = self.vit(x_in) feat_size = 14 hidden_size = 768 x1 = self.conv1(x_in) v1 = hidden_states_out[3] enc2 = self.encoder2(self.proj_feat(v1, hidden_size ,feat_size)) x2 = self.conv2(torch.cat([self.pool(x1), enc2], dim=1)) v2 = hidden_states_out[6] enc3 = self.encoder3(self.proj_feat(v2, hidden_size, feat_size)) x3 = self.conv3(torch.cat([self.pool(x2), enc3], dim=1)) v3 = hidden_states_out[9] enc4 = self.encoder4(self.proj_feat(v3, hidden_size, feat_size)) x4 = self.conv4(torch.cat([self.pool(x3), enc4], dim=1)) enc5 = self.encoder5(self.proj_feat(v4, hidden_size, feat_size)) x5 = self.conv5(torch.cat([self.pool(x4), enc5], dim=1)) # Decoder # Layer1 # 1 x50 = self.up(x5) xd4 = self.Att4(g=x50, x=x4) d1 = self.deconv1(torch.cat([xd4, x50], 1)) # 2 x51 = self.up(x5) xd41 = self.Att41(g=x51, x=x4) d11 = self.deconv11(torch.cat([xd41, x51], 1)) # Layer2 #1 x40 = self.up(d1) xd3 = self.Att3(g=x40, x=x3) d2 = self.deconv2(torch.cat([xd3, x40], 1)) #2 x41 = self.up(d11) xd31 = self.Att31(g=x41, x=x3) d21 = self.deconv21(torch.cat([xd31, x41], 1)) # Layer3 #1 x30 = self.up(d2) xd2 = self.Att2(g=x30, x=x2) d3 = self.deconv3(torch.cat([xd2, x30], 1)) #2 x31 = self.up(d21) xd21 = self.Att21(g=x31, x=x2) d31 = self.deconv31(torch.cat([xd21, x31], 1)) # Layer4 #1 x20 = self.up(d3) xd1 = self.Att1(g=x20, x=x1) d4 = self.deconv4(torch.cat([xd1, x20], 1)) #2 x21 = self.up(d31) xd11 = self.Att11(g=x21, x=x1) d41 = self.deconv41(torch.cat([xd11, x21], 1)) output = self.final(d4) output2 = self.final1(d41) return 
output, output2 model = DTrAttUnet(in_channels=3, out_channels=1, img_size=224) trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) non_trainable_params = sum(p.numel() for p in model.parameters() if not p.requires_grad) total_params = trainable_params + non_trainable_params print(f"Trainable params: {trainable_params:,}") print(f"Non-trainable params: {non_trainable_params:,}") print(f"Total params: {total_params:,}") [/code] Код архитектуры взят непосредственно из файла Architecture.py. Вы можете попробовать его напрямую с помощью colab. Мой вопрос: Может ли быть разница в том, как определенные компоненты MONAI или ViT инициализируются или подсчитываются в более новых версиях? Я упускаю детали конфигурации или изменения в архитектуре, которые могли бы объяснить эту разницу? Действительно ли общие параметры в их публичной реализации превышают 100 M? Подробнее здесь: [url]https://stackoverflow.com/questions/79811131/discrepancy-in-model-parameters-70-13m-in-paper-vs-100m-in-implementation[/url]