Почему порядок возвращаемых переменных так сильно влияет на производительность функции jax jitted?Python

Программы на Python
Ответить
Anonymous
 Почему порядок возвращаемых переменных так сильно влияет на производительность функции jax jitted?

Сообщение Anonymous »

В jax вы можете передать аргумент функции, чтобы сэкономить память и время выполнения, если этот аргумент больше не используется.

Если вы знаете, что один из входных данных не нужен после вычисления, и если он соответствует форме и типу элемента одного из выходных данных, вы можете указать, что вы хотите, чтобы соответствующий входной буфер был передан для хранения выходных данных. Это уменьшит объем памяти, необходимой для выполнения, на размер пожертвованного буфера. отсюда.

Но очень странно, я обнаружил, что в моем коде ddpg порядок переданной переменной (совпадающий вывод) среди всех возвращаемых переменных может сильно повлиять на производительность. Ниже приведен сокращенный пример.
Как показано в следующем псевдокоде, за исключением порядка возвращаемых переменных, все идентично, но производительность сильно различается. Для меня это очень странно.
Псевдокод:

Код: Выделить всё

def train_one_step_return_later(key, model_params, buffer, buffer_state):
# sample data -> add to buffer -> sample from buffer -> update model
...
return model_params, buffer_state, key

def train_one_step_return_early(key, model_params, buffer, buffer_state):
model_params, buffer_state, key = train_one_step_return_later(
key, model_params, buffer, buffer_state)
return buffer_state, model_params, key

def benchmark_return_later():
# jit train_one_step_return_later and donate buffer_state
# warm up jitted train_one_step_return_later
# timing jitted train_one_step_return_later
...

def benchmark_return_early():
# jit train_one_step_return_early and donate buffer_state
# warm up jitted train_one_step_return_early
# timing jitted train_one_step_return_early
...

if __name__ == "__main__":
print("-------- return later ---------")
benchmark_return_later()
print("\n-------- return early ---------")
benchmark_return_early()
вывод:

Код: Выделить всё

-------- return later ---------
Average time: 638 microseconds

-------- return early ---------
Average time: 73 microseconds
Я пытался сократить код дальше, но любое существенное сокращение приводит к исчезновению этого явления. Например, если я сведу этап обновления модели к непосредственному генерированию градиента, а не к вычислениям на основе выборочных данных, разрыв в производительности исчезнет. Или, если я заменю буфер пустым массивом, он также исчезнет.
Полный код: см. здесь.

Подробнее здесь: https://stackoverflow.com/questions/798 ... formance-s
Ответить

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

Вернуться в «Python»