Попытка реализовать векторизованную версию алгоритма (из вычислительной геометрии) с использованием Jax. Я сделал минимальный рабочий пример, используя LinkedList, чтобы конкретно выразить мой запрос (в противном случае я использую DCEL).
Идея состоит в том, что этот векторизованный алгоритм будет проверять определенные критерии по DCEL. Для простоты я заменил эту «процедуру проверки критериев» простым алгоритмом суммирования.
import jax
from jax import vmap
import jax.numpy as jnp
class Node:
# Constructor to initialize the node object
def __init__(self, data):
self.data = data
self.next = None
class LinkedList:
def __init__(self):
self.head = None
def push(self, new_data):
new_node = Node(new_data)
new_node.next = self.head
self.head = new_node
def printList(self):
temp = self.head
while(temp):
print (temp.data,end = " ")
temp = temp.next
def summate(list) :
prev = None
current = list.head
sum = 0
while(current is not None):
sum += current.data
next = current.next
current = next
return sum
list1 = LinkedList()
list1.push(20)
list1.push(4)
list1.push(15)
list1.push(85)
list2 = LinkedList()
list2.push(19)
list2.push(13)
list2.push(2)
list2.push(13)
#list(map(summate, ([list1, list2])))
vmap(summate)(jnp.array([list1, list2]))
Я получаю следующую ошибку.
TypeError: Value '<__main__.LinkedList object at 0x1193799d0>' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.
Цель состоит в том, что если у меня есть набор, скажем, 10 000 Linkedlists, я смогу применить эту функцию суммирования к каждому LinkedList в векторизованном виде. Я реализовал то, что хотел, на базовом языке Python, но хочу сделать это в Jax, поскольку существует более крупная вероятностная функция, для которой я буду использовать эту подпроцедуру (это цепь Маркова).
Возможно, я совершенно не могу работать с такими структурами данных через Jax, поскольку ошибка предполагает, что поддерживаются только числовые типы. Могу ли я каким-то образом использовать pytrees
, чтобы смягчить это ограничение?
Будет заманчиво предложить мне использовать простой список из jnp, но я использую Linkedlist просто как пример простой (st) структуры данных. Как упоминалось ранее, на самом деле я работаю над DCEL.
PS: код Linkedlist был взят с сайта GeeksForGeeks, так как хотелось быстро придумать минимальный рабочий пример.
🤔 А знаете ли вы, что...
Python имеет множество фреймворков для веб-разработки, такие как Django и Flask.
Цель состоит в том, что если у меня есть набор, скажем, 10 000 Linkedlists, я смогу применить эту функцию суммирования к каждому LinkedList в векторизованном виде.
Эта цель недостижима с использованием JAX. Вы можете зарегистрировать свой класс как собственный Pytree, чтобы он работал с функциями JAX (см. Расширение pytrees), но это не означает, что вы можете векторизовать операцию над списком таких объектов.
Преобразования JAX, такие как vmap
и jit
, работают для данных, хранящихся с шаблоном структуры массивов (например, один объект LinkedList
, содержащий массивы, которые представляют несколько пакетных связанных списков), а не шаблон массива структур (например, список из нескольких объектов LinkedList
). .
Кроме того, используемый вами алгоритм, основанный на цикле while
, несовместим с преобразованиями JAX (см. Резкие биты JAX: поток управления), а дерево узлов с динамическим размером не вписывается в ограничения статической формы JAX-программы.
Мне бы хотелось указать вам правильное направление, но я считаю, что вам нужно либо отказаться от использования JAX, либо отказаться от использования динамических связанных списков. Вы не сможете сделать и то, и другое.