Отслеживание прогресса Pymc в стримлитеPython

Программы на Python
Ответить
Anonymous
 Отслеживание прогресса Pymc в стримлите

Сообщение Anonymous »

У меня есть проект Streamlit для разработки моделей PyMC с использованием его виджетов в браузере. Работает очень хорошо, но в момент запуска семплера между компиляцией моделей и окончанием семплирования может быть большая задержка. Я хотел бы предупредить пользователя о том, что процесс работает должным образом, и, если возможно, получить представление о скорости.
Я пытался связать индикаторы выполнения pymc иstreamlit с обратным вызовом... Однако у меня не получается заставить его работать... Есть ли у кого-нибудь исправление или дополнительный подход к отслеживанию прогресса?
Большое спасибо за любые отзывы.
import streamlit as st
import pymc as pm
from pymc.progress_bar import ProgressBarManager

st.title("PyMC + Nutpie Sampler")

n_draws = 1000
n_tune = 1000
n_chains = 4

if st.button("Run Sampling"):

total_steps = n_draws * n_chains
chain_draws = {i: 0 for i in range(n_chains)}

progress_bar = st.progress(0, text=f"Sampling for {n_chains * n_draws} steps...")

old_update = ProgressBarManager.update

def new_update(self, chain_idx, is_last, draw, tuning, stats):
if not tuning:
chain_draws[chain_idx] += 1
completed = sum(chain_draws.values())
progress = min(completed / total_steps, 1.0)
progress_bar.progress(
progress,
text=f"Sampling... {completed}/{total_steps} steps ({progress * 100:.1f}%)"
)
old_update(self, chain_idx, is_last, draw, tuning, stats)

ProgressBarManager.update = new_update

with pm.Model() as model:
mu = pm.Normal("mu", mu=0, sigma=1)
obs = pm.Normal("obs", mu=mu, sigma=1, observed=[1, 2, 3])

trace = pm.sample(
draws=n_draws,
tune=n_tune,
chains=n_chains,
nuts_sampler="nutpie",
nuts_sampler_kwargs={"backend": "jax", "gradient_backend": "jax"},
)

# Restore original to avoid side effects on reruns
ProgressBarManager.update = old_update

progress_bar.progress(1.0, text="Sampling complete!")
st.success("Sampling complete!")

st.subheader("Posterior Summary")
st.dataframe(pm.stats.summary(trace))
Ответить

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

Вернуться в «Python»