Это упрощенный файл app.py, в котором я провожу тесты:
Код: Выделить всё
import os
from flask import Flask, request, jsonify
from flask_cors import CORS
from realesrgan import RealESRGANer
from PIL import Image
import numpy as np
import io
app = Flask(__name__)
CORS(app)
# Initialize the RealESRGAN model
model_path = "./RealESRGAN_x4plus_anime_6B.pth" # Path to the model that can be dowloaded here: https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth
model = RealESRGANer(scale=4, model_path=model_path) # Initialize the model with the specified scale
@app.route('/upscale', methods=['POST'])
def upscale_image():
if 'image' not in request.files:
return jsonify({"error": "No image provided"}), 400
file = request.files['image']
try:
# Read the image file
img = Image.open(file.stream).convert('RGB')
# Convert image to numpy array and upscale
img_array = np.array(img)
upscaled_image = model.predict(img_array)
# Convert upscaled image back to PIL format
upscaled_image = Image.fromarray(upscaled_image)
# Save the upscaled image to a byte stream
img_byte_arr = io.BytesIO()
upscaled_image.save(img_byte_arr, format='PNG')
img_byte_arr.seek(0)
return jsonify({"message": "Image upscaled successfully", "upscaled_image": img_byte_arr.getvalue().decode('latin1')})
except Exception as e:
return jsonify({"error": str(e)}), 500
if __name__ == '__main__':
app.run(debug=True)
Код: Выделить всё
model.load_state_dict(loadnet[keyname], strict=True)
^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'load_state_dict'
Вот несколько вещей, которые я пытался сделать во время отладки:
Загрузка модели: Я попробовал загрузить модель «вручную» с помощью следующего кода:
Код: Выделить всё
# Detect the device (GPU or CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Initialize RealESRGANer
self.model = RealESRGANer(
scale=4, # Upscale factor
model_path=model_path, # Model path for pretrained weights
tile=0, # No tiling
tile_pad=10, # Padding if tiling is used
pre_pad=0, # Pre-padding
half=False, # Avoid half precision issues
device=device # Use detected device (CUDA or CPU)
)
# Load the weights into the model architecture
state_dict = torch.load(model_path, map_location=device)
# Ensure 'params_ema' is present in the state_dict
if "params_ema" in state_dict:
self.model.model.load_state_dict(state_dict["params_ema"], strict=True)
else:
raise ValueError("The state_dict does not contain 'params_ema'. Please check the model file.")
# Set the model to evaluation mode (important for inference)
self.model.model.eval()
Код: Выделить всё
print(os.path.exists("./RealESRGAN_x4plus_anime_6B.pth"))
Подробнее здесь: https://stackoverflow.com/questions/791 ... -dict-when