Моя классификация А.И. модель не может превзойти определенный порог точности проверкиPython

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Моя классификация А.И. модель не может превзойти определенный порог точности проверки

Сообщение Anonymous »

Я недавно начинаю с А.И. И для проекта я хотел создать модель, которая может классифицировать, какую болезнь у вас есть на основе рентгеновского излучения грудной клетки, после написания модели и использования архитектуры Densenet (так как это было лучшее из EffifiveNet и Mobilenet), но после сохранения модели после каждого эпоха, изменяющего оптимизатор ADAM (его уровень обучения и норму клипа, в настоящее время на 0,001, 20.0 соответственно), модели сохраняют циркуляцию. Есть ли что -то, что я могу изменить, чтобы сделать мою модель лучше обобщать?clasification_data = pd.read_csv("/mnt/d/CXR8/PruneCXR/miccai2023_nih-cxr-lt_labels_train.csv",sep=",").values.tolist()
training_frequencies = [0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]
for clases in clasification_data:
for i in range(0,20):
training_frequencies = training_frequencies + clases[i+1]

training_labels = []
for clases in clasification_data:
training_labels.append(clases[1:-1])

total_samples = sum(training_frequencies)

# Calculate weights (inverse of frequency)
weights = [total_samples / freq for freq in training_frequencies]

# Normalize weights (optional)
normalized_weights = np.array(weights) / sum(weights) * len(training_frequencies)

class_weight_dict = {i: weight for i, weight in enumerate(normalized_weights)}

print("Class Weights:", class_weight_dict)

< /code>
model = DenseNet121(weights=None, include_top=False)
model = Model(inputs=model.input, outputs=Dense(20, activation="sigmoid")(GlobalAveragePooling2D()(model.output)))

model.load_weights('./checkpoints/weights_epoch_24.weights.h5')
adam_optimizer = Adam(
learning_rate=0.00001, # Default: 0.001
clipnorm = 20.0
)
model.compile(optimizer=adam_optimizer, loss='binary_crossentropy', metrics=[Precision(),Accuracy()])
< /code>
def create_dynamic_loaded_data(csv_path):
dummy = pd.read_csv(csv_path,sep=",").values.tolist()
img_names = tf.convert_to_tensor([label[0] for label in dummy])
dummy = np.array([np.array(label[1:-1]) for label in dummy], dtype=np.float32)

general_img_path = "/mnt/d/CXR8/images/"

def preprocess_image(img_name, label):
image = tf.io.read_file(general_img_path + img_name)
image = tf.image.decode_jpeg(image, channels=3) # Ensure 3 channels
image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE]) # Resize to match input size
image = image / 255.0 # Normalize to [0, 1]
return image, label

#create shuffle
dummy_data = tf.data.Dataset.from_tensor_slices((img_names, dummy))
dummy_data = dummy_data.map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
dummy_data = dummy_data.shuffle(buffer_size=1000, seed=42)
dummy_data = dummy_data.batch(20).prefetch(tf.data.AUTOTUNE)

return dummy_data

train_dataset = create_dynamic_loaded_data("/mnt/d/CXR8/PruneCXR/miccai2023_nih-cxr-lt_labels_train.csv")
validation_dataset = create_dynamic_loaded_data("/mnt/d/CXR8/PruneCXR/miccai2023_nih-cxr-lt_labels_val.csv")
< /code>
checkpoint_dir = "./checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

# Define the checkpoint callback
checkpoint_path = os.path.join(checkpoint_dir, "weights_epoch_{epoch:02d}.weights.h5")
checkpoint_callback = ModelCheckpoint(
filepath=checkpoint_path,
save_weights_only=True, # Save only the weights
save_best_only=False, # Save after every epoch
verbose=1
)

history = []
with tf.device('/GPU:0'):
history = model.fit(train_dataset,
validation_data=validation_dataset,
epochs=128,
verbose=1,
class_weight=class_weight_dict,
callbacks=[checkpoint_callback],
initial_epoch=24
)
< /code>
Just Tried with a 10^-6 learning rate, after the first epoch it gave a 0.457 but after that epoch it again plummeted around 0.44, i don't think lowering the learning rate again is a good solution, what else can i do(for this one i removed the clipnorm, i had to use it before cause i was having Nan loss)?
(the dataset that i am using: https://nihcc.app.box.com/v/ChestXray-NIHCC)

Подробнее здесь: https://stackoverflow.com/questions/795 ... n-threshol
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение

Вернуться в «Python»