Я пытаюсь создать модель TensorFlow, которая применяет обучаемый фильтр верхних частот в частотной области к изображению ( shape: 256,256,3 ), уточняет его с помощью CNN, а затем применяет к исходному изображению. Я хочу, чтобы радиус r фильтра верхних частот был изучен во время обучения.
# --- Load images ---
image_path = "/content/IMG-0002-00013.jpg"
mask_path = "/content/IMG-0002-00013.jpg"
image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE).astype(np.float32)/255.0
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE).astype(np.float32)/255.0
image = image[np.newaxis, ..., np.newaxis] # (1,H,W,1)
mask = mask[np.newaxis, ..., np.newaxis]
# --- Learnable high-pass in frequency domain ---
class LearnableHighPassFreq(tf.keras.layers.Layer):
def __init__(self, init_radius=10.0, slope=2.0):
super().__init__()
# radius in log-space so it's always positive
self.log_radius = tf.Variable(np.log(init_radius), dtype=tf.float32, trainable=True)
self.slope = slope # soft sigmoid slope
def call(self, x):
batch, h, w, c = tf.unstack(tf.shape(x))
radius = tf.exp(self.log_radius)
# Frequency grid
u = tf.range(-h//2, h//2, dtype=tf.float32)
v = tf.range(-w//2, w//2, dtype=tf.float32)
U, V = tf.meshgrid(u, v, indexing='ij')
D = tf.sqrt(U**2 + V**2)
# Soft high-pass filter in frequency domain
H = tf.sigmoid(self.slope * (D - radius))
H = tf.expand_dims(H, axis=0) # batch dim
H = tf.expand_dims(H, axis=-1) # channel dim
# Apply filter to FFT of the image
x_fft = tf.signal.fft2d(tf.cast(x, tf.complex64))
x_filtered = tf.signal.ifft2d(x_fft * tf.cast(H, tf.complex64))
# Return real spatial-domain high-pass feature map
return tf.math.real(x_filtered)
# --- CNN refinement for high-pass features ---
def build_refinement_cnn():
inputs = layers.Input(shape=(None, None, 1))
x = layers.Conv2D(32,3,padding='same',activation='relu')(inputs)
x = layers.Conv2D(32,3,padding='same',activation='relu')(x)
x = layers.Conv2D(1,1,activation='sigmoid')(x) # refined filter
return models.Model(inputs,x)
# --- Frequency-aware loss ---
def frequency_aware_loss(y_true, y_pred, alpha=0.1):
# Soft Dice
y_true_f = tf.reshape(y_true,[-1])
y_pred_f = tf.reshape(y_pred,[-1])
intersection = tf.reduce_sum(y_true_f * y_pred_f)
dice_loss = 1 - (2.*intersection + 1e-6)/(tf.reduce_sum(y_true_f)+tf.reduce_sum(y_pred_f)+1e-6)
# FFT frequency-aware
y_true_fft = tf.signal.fft2d(tf.cast(y_true[...,0],tf.complex64))
y_pred_fft = tf.signal.fft2d(tf.cast(y_pred[...,0],tf.complex64))
h, w = tf.shape(y_true)[1], tf.shape(y_true)[2]
h = tf.cast(h, tf.float32)
w = tf.cast(w, tf.float32)
u = tf.range(tf.shape(y_true)[1], dtype=tf.float32)
v = tf.range(tf.shape(y_true)[2], dtype=tf.float32)
U,V = tf.meshgrid(u,v,indexing='ij')
D = tf.sqrt((U-h/2)**2 + (V-w/2)**2)
D = D/tf.reduce_max(D)
D = tf.expand_dims(D, axis=0)
freq_loss = tf.reduce_mean(tf.abs(tf.cast(y_pred_fft - y_true_fft, tf.float32)*D))
return dice_loss + alpha*freq_loss
# --- Build complete model ---
inputs = layers.Input(shape=(None,None,1))
# Step 1: learnable high-pass
high_pass = LearnableHighPassFreq()(inputs)
# Step 2: refine high-pass features
refined_filter = build_refinement_cnn()(high_pass)
# Step 3: apply refined filter to original image
masked_image = layers.Multiply()([inputs, refined_filter])
model = models.Model(inputs, masked_image)
# Compile and train
model.compile(optimizer=optimizers.Adam(1e-3), loss=frequency_aware_loss)
model.fit(image, mask, epochs=100, verbose=2)
# Print learned radius
learned_r = tf.exp(model.layers[1].log_radius).numpy()
print("Learned frequency-domain radius r:", learned_r)
Проблема в том, что после обучения изученный радиус r никогда не меняется по сравнению со своим начальным значением (10,0).
Итак, почему же радиус r не обновляется во время обучения? Это из-за проблем с градиентом, обучения одному изображению или чего-то еще? Как я могу сделать обучаемый радиус частотной области действительно обучаемым?
Я пытаюсь создать модель TensorFlow, которая применяет обучаемый фильтр верхних частот в частотной области к изображению ( shape: 256,256,3 ), уточняет его с помощью CNN, а затем применяет к исходному изображению. Я хочу, чтобы радиус r фильтра верхних частот был изучен во время обучения. [code]# --- Load images --- image_path = "/content/IMG-0002-00013.jpg" mask_path = "/content/IMG-0002-00013.jpg"
# Frequency grid u = tf.range(-h//2, h//2, dtype=tf.float32) v = tf.range(-w//2, w//2, dtype=tf.float32) U, V = tf.meshgrid(u, v, indexing='ij') D = tf.sqrt(U**2 + V**2)
# Soft high-pass filter in frequency domain H = tf.sigmoid(self.slope * (D - radius)) H = tf.expand_dims(H, axis=0) # batch dim H = tf.expand_dims(H, axis=-1) # channel dim
# Apply filter to FFT of the image x_fft = tf.signal.fft2d(tf.cast(x, tf.complex64)) x_filtered = tf.signal.ifft2d(x_fft * tf.cast(H, tf.complex64))
# Return real spatial-domain high-pass feature map return tf.math.real(x_filtered)
# --- CNN refinement for high-pass features --- def build_refinement_cnn(): inputs = layers.Input(shape=(None, None, 1)) x = layers.Conv2D(32,3,padding='same',activation='relu')(inputs) x = layers.Conv2D(32,3,padding='same',activation='relu')(x) x = layers.Conv2D(1,1,activation='sigmoid')(x) # refined filter return models.Model(inputs,x)
[/code] Проблема в том, что после обучения изученный радиус r никогда не меняется по сравнению со своим начальным значением (10,0). Итак, почему же радиус r не обновляется во время обучения? Это из-за проблем с градиентом, обучения одному изображению или чего-то еще? Как я могу сделать обучаемый радиус частотной области действительно обучаемым?