Реализация векторизованной функции над LinkedLists с использованием функции Jax vmap

Попытка реализовать векторизованную версию алгоритма (из вычислительной геометрии) с использованием Jax. Я сделал минимальный рабочий пример, используя LinkedList, чтобы конкретно выразить мой запрос (в противном случае я использую DCEL).

Идея состоит в том, что этот векторизованный алгоритм будет проверять определенные критерии по DCEL. Для простоты я заменил эту «процедуру проверки критериев» простым алгоритмом суммирования.


import jax
from jax import vmap
import jax.numpy as jnp

class Node: 
  
    # Constructor to initialize the node object 
    def __init__(self, data): 
        self.data = data 
        self.next = None

class LinkedList: 
  
    def __init__(self): 
        self.head = None
  
    def push(self, new_data): 
        new_node = Node(new_data) 
        new_node.next = self.head 
        self.head = new_node 

    def printList(self): 
        temp = self.head 
        while(temp): 
            print (temp.data,end = " ") 
            temp = temp.next

def summate(list) :
    prev = None
    current = list.head
    sum = 0
    while(current is not None): 
        sum += current.data
        next = current.next
        current = next
    return sum

list1 = LinkedList() 
list1.push(20) 
list1.push(4) 
list1.push(15) 
list1.push(85) 

list2 = LinkedList() 
list2.push(19)
list2.push(13)
list2.push(2)
list2.push(13)

#list(map(summate, ([list1, list2])))

vmap(summate)(jnp.array([list1, list2]))

Я получаю следующую ошибку.

TypeError: Value '<__main__.LinkedList object at 0x1193799d0>' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.

Цель состоит в том, что если у меня есть набор, скажем, 10 000 Linkedlists, я смогу применить эту функцию суммирования к каждому LinkedList в векторизованном виде. Я реализовал то, что хотел, на базовом языке Python, но хочу сделать это в Jax, поскольку существует более крупная вероятностная функция, для которой я буду использовать эту подпроцедуру (это цепь Маркова).

Возможно, я совершенно не могу работать с такими структурами данных через Jax, поскольку ошибка предполагает, что поддерживаются только числовые типы. Могу ли я каким-то образом использовать pytrees, чтобы смягчить это ограничение?

Будет заманчиво предложить мне использовать простой список из jnp, но я использую Linkedlist просто как пример простой (st) структуры данных. Как упоминалось ранее, на самом деле я работаю над DCEL.

PS: код Linkedlist был взят с сайта GeeksForGeeks, так как хотелось быстро придумать минимальный рабочий пример.

🤔 А знаете ли вы, что...
Python имеет множество фреймворков для веб-разработки, такие как Django и Flask.


1
55
1

Ответ:

Решено

Цель состоит в том, что если у меня есть набор, скажем, 10 000 Linkedlists, я смогу применить эту функцию суммирования к каждому LinkedList в векторизованном виде.

Эта цель недостижима с использованием JAX. Вы можете зарегистрировать свой класс как собственный Pytree, чтобы он работал с функциями JAX (см. Расширение pytrees), но это не означает, что вы можете векторизовать операцию над списком таких объектов.

Преобразования JAX, такие как vmap и jit, работают для данных, хранящихся с шаблоном структуры массивов (например, один объект LinkedList, содержащий массивы, которые представляют несколько пакетных связанных списков), а не шаблон массива структур (например, список из нескольких объектов LinkedList). .

Кроме того, используемый вами алгоритм, основанный на цикле while, несовместим с преобразованиями JAX (см. Резкие биты JAX: поток управления), а дерево узлов с динамическим размером не вписывается в ограничения статической формы JAX-программы.

Мне бы хотелось указать вам правильное направление, но я считаю, что вам нужно либо отказаться от использования JAX, либо отказаться от использования динамических связанных списков. Вы не сможете сделать и то, и другое.