Джакс отбрасывает код kd-дерева, занимая невероятно много времениPython

Программы на Python
Ответить Пред. темаСлед. тема
Anonymous
 Джакс отбрасывает код kd-дерева, занимая невероятно много времени

Сообщение Anonymous »

Я загнал себя в угол следующей ситуацией:
  • Я использую оптимизатор, для работы которого требуются плавные градиенты, и я Я использую Jax для автоматического дифференцирования. Поскольку этот код представляет собой Jax-jit, это означает, что все, что с ним связано, должно быть отслеживаемым Jax-jit.
  • Мне нужно интерполировать функцию для использования с оптимизатором, но я не могу использовать Библиотека Scipy, поскольку она несовместима с Jax (есть реализация jax.scipy.interpolate.RegularGridInterpolator, но она не гладкая — она поддерживает только линейную интерполяцию и интерполяцию ближайшего соседа).
  • Это означает, что мне нужно написать свой собственный Jax-совместимый плавный интерполятор, основанный на коде Scipy RBFInterpolator. Реализация этого очень хороша — она использует kd-дерево для поиска ближайших соседей запрашиваемой точки в пространстве, а затем использует их для построения локальной интерполяции. Это означает, что мне также нужно написать класс kd-tree, совместимый с Jax (класс Scipy также несовместим с Jax), что я и сделал.
Проблема возникает при jit-компиляции кода kd-дерева. Я написал его «стандартным способом», используя объекты для узлов дерева с полями левого и правого узла для дочерних элементов. В конечных узлах эти поля имеют значения None, что указывает на отсутствие дочерних узлов.
Код выполняется и функционально корректен, однако его jit-компиляция занимает много времени: 72 секунды для дерева из 64 координат, 131 секунда для 343 координат... и мой предполагаемый набор данных содержит более 14 миллионов точек. Я думаю, что внутренне Джекс прослеживает каждый возможный путь в дереве, поэтому это занимает так много времени. Результаты показывают, что это невероятно быстро: 0,0075 с для извлечения по 10 точкам kd-дерева против 0,4 с для перебора по всем точкам (для 343 точек). Именно такие скорости я надеюсь получить для использования в оптимизаторе (без тряски это будет слишком медленно). Однако это кажется невозможным, если время компиляции будет продолжать расти, как мы уже видели.
Я подумал, что проблема может заключаться в структуре дерева с множеством различных объектов. для хранения, поэтому мы также реализовали алгоритм поиска по kd-дереву, где дерево представлено набором массивов Jax-numpy (например, координата, значение, left и right; где каждый индекс соответствует точке в дереве), а для поиска по дереву используется итерация, а не рекурсия (это было непросто, но это работает!). Однако преобразование этого для работы с jit (изменение операторов if для jax.lax.cond) будет сложным, и прежде чем начать, я задавался вопросом, стоит ли оно того - наверняка у меня будет та же проблема: Jax будет отслеживать все ветви дерева до тех пор, пока не будут достигнуты «нулевые терминаторы» (значения -1 в левом и правом массиве), и это все равно займет очень много времени. скомпилировать. Я исследовал такие структуры, как jax.lax. while_loop, на случай, если они могут помочь?
(Я также написал гибрид двух подходов с массивом -дерево и алгоритм, основанный на рекурсии. В этом случае трассировка переходит в бесконечный цикл, я думаю, из-за того, что нулевой терминатор равен -1, а не None. Но массивы должны быть известны статически (они этого не делают). t меняются после построения и принадлежат объекту, который помечен как статический ввод), так что, возможно, решение кроется в этом, и я делаю что-то не так.)
Мне было интересно, если Я делаю что-то явно неправильное (или если мое понимание неверно), и могу ли я что-нибудь сделать, чтобы ускорить это? Можно ли ожидать, что время компиляции будет таким большим, когда нужно отследить так много путей кода? Я не думаю, что смогу создать jitted-функцию только один раз, а затем сохранить ее?
Меня беспокоит, что единственное решение может состоять в том, чтобы переписать код оптимизатора так, чтобы он не не использовать Jax (например, если я жестко закодирую производные и перепишу часть кода так, чтобы он работал непосредственно с массивами, а не векторизовался по входным данным).
Код доступен здесь: https://github.com/FluffyCodeMonster/jax_kd_tree
Даны все три описанные разновидности: дерево на основе узлов с рекурсией, дерево на основе массива с итерацией и массив -основанное дерево с рекурсией. Первый работает, но компилируется очень медленно по мере увеличения числа точек в дереве; второй тоже работает, но еще не написан в удобном виде. Последний написан для jit-компиляции, но не может jit-компилировать, поскольку попадает в бесконечную рекурсию.
Мне действительно нужно срочно заставить это работать, чтобы я мог получить результаты оптимизации.

Подробнее здесь: https://stackoverflow.com/questions/787 ... nt-of-time
Реклама
Ответить Пред. темаСлед. тема

Быстрый ответ

Изменение регистра текста: 
Смайлики
:) :( :oops: :roll: :wink: :muza: :clever: :sorry: :angel: :read: *x)
Ещё смайлики…
   
К этому ответу прикреплено по крайней мере одно вложение.

Если вы не хотите добавлять вложения, оставьте поля пустыми.

Максимально разрешённый размер вложения: 15 МБ.

  • Похожие темы
    Ответы
    Просмотры
    Последнее сообщение
  • Джакс отслеживает статический аргумент
    Anonymous » » в форуме Python
    0 Ответы
    34 Просмотры
    Последнее сообщение Anonymous
  • Джакс отслеживает статический аргумент
    Anonymous » » в форуме Python
    0 Ответы
    17 Просмотры
    Последнее сообщение Anonymous
  • JDK 11; ДЖАКС-WS; Поставщик com.sun.xml.internal.ws.spi.ProviderImpl не найден
    Anonymous » » в форуме JAVA
    0 Ответы
    61 Просмотры
    Последнее сообщение Anonymous
  • Джакс постоянный кэш -разрывы вызывает недостаток?
    Anonymous » » в форуме Python
    0 Ответы
    9 Просмотры
    Последнее сообщение Anonymous
  • SVG CSS отбрасывает тень при наведении курсора мыши
    Anonymous » » в форуме CSS
    0 Ответы
    24 Просмотры
    Последнее сообщение Anonymous

Вернуться в «Python»