Как использовать pmap с массивами переменной длины в JAX? ⇐ Python
-
Anonymous
Как использовать pmap с массивами переменной длины в JAX?
Я работаю над реализацией JAX, которая предполагает распараллеливание вычислений над массивами переменной длины. Однако я столкнулся с ошибкой ValueError: pmap получил несовместимые размеры для отображаемых осей массива.
Вот мой фрагмент кода:
def Pair_wise_term_pmap(num_users, user_ids, Movie_ids, рейтинги, U_mean, V_mean, b_mean_v): # Предварительно подготовьте маски и массивы user_masks = [user_ids == i для i в диапазоне (num_users)] user_movie_ids_list = [movie_ids[mask].reshape(-1, 1) для маски в user_masks] user_ratings_list = [рейтинги[маска] для маски в user_masks] # Распараллеливаем цикл с помощью pmap def Parallel_body (i, user_movie_ids, user_ratings): u_m = U_mean.reshape(-1, 1) V_m = V_mean[user_movie_ids].reshape(len(user_movie_ids), -1) b_n = b_mean_v[user_movie_ids].reshape(-1, 1) # Код, использующий u_m, V_m, b_n вернуть что-то результат = pmap(parallel_body)(jnp.arange(num_users), user_movie_ids_list, user_ratings_list) вернуть jnp.sum(partial_errors)
Ошибка возникает при попытке использовать pmap с массивами user_movie_ids_list и user_ratings_list, поскольку они имеют разную длину.
Я ожидал, что код будет работать нормально, однако получил распечатку следующего содержания:
ValueError: pmap получил несовместимые размеры отображаемых осей массива: * большинство топоров (из них 16) имели размер 35, напр. ось 0 аргументаnested_user_data[70][0] типа int32[35,1]; * некоторые оси (их 14) имели размер 26, напр. ось 0 аргументаnested_user_data[24][0] типа int32[26,1]; * некоторые оси (их 14) имели размер 21, напр. ось 0 аргументаnested_user_data[25][0] типа int32[21,1]; * некоторые оси (их 14) имели размер 20, напр. ось 0 аргументаnested_user_data[52][0] типа int32[20,1];
Я работаю над реализацией JAX, которая предполагает распараллеливание вычислений над массивами переменной длины. Однако я столкнулся с ошибкой ValueError: pmap получил несовместимые размеры для отображаемых осей массива.
Вот мой фрагмент кода:
def Pair_wise_term_pmap(num_users, user_ids, Movie_ids, рейтинги, U_mean, V_mean, b_mean_v): # Предварительно подготовьте маски и массивы user_masks = [user_ids == i для i в диапазоне (num_users)] user_movie_ids_list = [movie_ids[mask].reshape(-1, 1) для маски в user_masks] user_ratings_list = [рейтинги[маска] для маски в user_masks] # Распараллеливаем цикл с помощью pmap def Parallel_body (i, user_movie_ids, user_ratings): u_m = U_mean.reshape(-1, 1) V_m = V_mean[user_movie_ids].reshape(len(user_movie_ids), -1) b_n = b_mean_v[user_movie_ids].reshape(-1, 1) # Код, использующий u_m, V_m, b_n вернуть что-то результат = pmap(parallel_body)(jnp.arange(num_users), user_movie_ids_list, user_ratings_list) вернуть jnp.sum(partial_errors)
Ошибка возникает при попытке использовать pmap с массивами user_movie_ids_list и user_ratings_list, поскольку они имеют разную длину.
Я ожидал, что код будет работать нормально, однако получил распечатку следующего содержания:
ValueError: pmap получил несовместимые размеры отображаемых осей массива: * большинство топоров (из них 16) имели размер 35, напр. ось 0 аргументаnested_user_data[70][0] типа int32[35,1]; * некоторые оси (их 14) имели размер 26, напр. ось 0 аргументаnested_user_data[24][0] типа int32[26,1]; * некоторые оси (их 14) имели размер 21, напр. ось 0 аргументаnested_user_data[25][0] типа int32[21,1]; * некоторые оси (их 14) имели размер 20, напр. ось 0 аргументаnested_user_data[52][0] типа int32[20,1];
Мобильная версия