Код: Выделить всё
embeds_2d: pd.DataFrame = visualizer(topics)
data = embeds_2d[["x", "y"]].values
labels = embeds_2d["labels"].values
with st.spinner("Rendering the image..."):
with tempfile.TemporaryDirectory() as tmp_dir:
fig, ax = datamapplot.create_plot(
data, labels,
noise_label="Unimportant topics",
title="Topic Map",
sub_title="Visual representation of topics found within the book",
label_font_size=11,
)
fig.savefig(f"{tmp_dir}/saved_img.jpg")
matplotlib_fig = PIL.Image.open(f"{tmp_dir}/saved_img.jpg")
st.plotly_chart(matplotlib_fig)
Я подумал, что что-то не так с самим классом datamapplot, поэтому попробовал запустить его локально.
Все сработало нормально.
Тогда я подумал: поскольку он не работает в приложенииstreamlit, а работает локально, возможно, я мог бы добавить конечную точку к существующему серверу, который я создаю в fastapi, и посмотреть, даст ли это какие-либо результаты.
Поэтому я создал следующую конечную точку на своем существующем сервере.
Код: Выделить всё
@app.post("/create_img")
async def create_image(img_model: DataMapPlotInputModel) -> GeneratedImageModel:
with tempfile.TemporaryDirectory() as tmp_dir:
data = np.hstack((np.array(img_model.X_col).reshape(-1, 1),
np.array(img_model.Y_col).reshape(-1, 1)))
labels = np.array(img_model.labels).reshape(-1,)
fig, _ = datamapplot.create_plot(
data, labels,
noise_label="Unimportant topics",
title="Topic Map",
sub_title="Visual representation of topics found within the book",
label_font_size=11,
)
fig.savefig(f"{tmp_dir}/saved_img.jpg")
matplotlib_fig = PIL.Image.open(f"{tmp_dir}/saved_img.jpg")
return GeneratedImageModel(
encoded_image=encode_image(matplotlib_fig)
)
Код: Выделить всё
import numpy as np
import requests
import io
import matplotlib.pyplot as plt
from image_utils.decoder import decode_image
import matplotlib
matplotlib.rcParams["figure.dpi"] = 72
URL: str = "http://127.0.0.1:8000"
data_map_file = requests.get("https://github.com/TutteInstitute/datamapplot/raw/main/examples/Wikipedia-data_map.npy")
wikipedia_data_map = np.load(io.BytesIO(data_map_file.content))
label_file = requests.get(
"https://github.com/TutteInstitute/datamapplot/raw/main/examples/Wikipedia-cluster_labels.npy")
wikipedia_labels = np.load(io.BytesIO(label_file.content), allow_pickle=True)
x_col = wikipedia_data_map[:, 0].tolist()[0:10]
y_col = wikipedia_data_map[:, 1].tolist()[0:10]
labels = wikipedia_labels.tolist()[0:10]
response = requests.post(URL + "/create_img", json={
"X_col":x_col,
"Y_col": y_col,
"labels": labels,
}, timeout=1000)
img = decode_image(response.json()["encoded_image"])
print(np.asarray(img))
Подробнее здесь: https://stackoverflow.com/questions/791 ... s-crashing