Я работаю с очень большими (несколько ГБ) двумерными квадратными массивами NumPy . Учитывая входной массив a
, для каждого элемента я хотел бы найти направление его крупнейшего соседнего соседа. Я использую предоставленный раздвижной вид окна, чтобы избежать создания ненужных копий:
# a is an L x L array of type np.float32
swv = sliding_window_view(a, (3, 3)) # (L-2) x (L-2) x 3 x 3
directions = swv.reshape(L-2, L-2, 9)[:,:,1::2].argmax(axis = 2).astype(np.uint8)
Однако вызов reshape здесь создает копию (L-2) x (L-2) x 9
вместо представления, что потребляет нежелательно большой кусок памяти. Есть ли способ выполнить эту операцию векторизованно, но с меньшим объемом памяти?
Обновлено: Многие ответы ориентированы на NumPy, который использует процессор (поскольку я изначально спрашивал об этом, чтобы упростить проблему). Будет ли оптимальная стратегия использования CuPy другой, то есть NumPy для графического процессора? Насколько я знаю, это делает использование Numba гораздо менее простым.
🤔 А знаете ли вы, что...
Python поддерживает многозадачность и многопоточность.
вызов reshape здесь создает копию (L-2) x (L-2) x 9 вместо представления
Это связано с тем, что две последние оси целевого массива не могут быть изменены вместе здесь, в Numpy. Действительно, это означало бы, что шаг последнего измерения будет варьироваться между элементами, что не поддерживается (и, конечно, никогда не будет, потому что это сделает многие операции очень медленными и намного более сложными). Шаги могут быть постоянными только для данной оси. Когда представление не может быть создано, Numpy выполняет дорогостоящее копирование.
Есть ли способ выполнить эту операцию векторизованно, но с меньшим объемом памяти?
Ключевым моментом в таком случае является выполнение операции по частям. Если входные данные большие, это может быть быстрее, чем вычисление окончательного массива за одну уникальную операцию из-за ошибок кэша ЦП и страниц. Действительно, временный массив можно повторно использовать в памяти. При этом куски не должны быть слишком маленькими, иначе накладные расходы на вызов функции Numpy и цикла CPython станут дорогими.
Если L
довольно большой, вы можете просто перебирать построчно. В противном случае вам придется вычислять куски строк. Вот пример:
swv = np.lib.stride_tricks.sliding_window_view(a, (3, 3)) # (L-2) x (L-2) x 3 x 3
out = np.empty((L-2, L-2), dtype=np.uint8)
k = 8
for i in range(0, L-2, k):
out[i:i+k] = swv[i:i+k,:,:,:].reshape(-1, L-2, 9)[:,:,1::2].argmax(axis = 2).astype(np.uint8)
Вот результаты производительности a = np.random.rand(L, L).astype(np.float32)
с L = 1000
на моей машине с i5-9600KF и Numpy 1.24.3:
Initial implementation:
- time: 51 ms
- memory overhead: O(L**2)
Proposed implementation:
- time: 43 ms
- memory overhead: O(L * k)
Вычисление здесь происходит быстрее для всех k<30. При k>=30 массивы слишком велики, чтобы вычисления были эффективными, и они занимают примерно то же время, что и ваши вычисления (на самом деле, предлагаемая реализация даже в этом случае немного быстрее). Мы также можем заключить, что циклы CPython не являются медленными, пока фрагменты достаточно велики, чтобы накладные расходы были небольшими по сравнению со временем вычислений. Вычисления также занимают меньше памяти. Единственным недостатком является то, что код больше. Бесплатного обеда не существует.
Обратите внимание, что разумным значением k
может быть max(int(512*1024 / (3*3*a.itemsize*(L-2)) + 0.5), 1)
. По этой формуле вычисления, если это возможно, должны занимать не более нескольких МБ ОЗУ. Если это невозможно, потому что k=1, тогда необходимо взять C*a.itemsize*(L-2)*3*3/1024**2
MiB, где C
— небольшая константа (обычно 2).
Вот расширенный тест с другими конкурентными реализациями:
nocomment's first implementation ("mine"):
- time: 19 ms
- memory overhead: O(L**2)
nocomment's second implementation ("mine4"):
- time: 13 ms
- memory overhead: O(L**2)
Native scalar code:
- time: 1.5 ms
- memory overhead: O(1)
ken's best implementation ("neighbor_argmax"):
- time: 0.38 ms
- memory overhead: O(1)
Optimized native SIMD code:
- time: 0.25 ms
- memory overhead: O(1)
Первая реализация nocomment быстрее, но требует значительно больше памяти (хотя и меньше, чем исходный код). Действительно, необходимо одновременно выделить как минимум 3 временных логических массива. Размер каждого логического массива составляет (L-2)**2
байт. Это означает, что необходимо выделить как минимум 3 * (L-2)**2
байт. Это значительно больше, чем C * k * L
(где C
— константа, которая должна быть от 30 до 50), пока k
остается небольшим, а L
относительно большим. Вторая реализация на моей машине работает быстрее, но для нее также потребуется больше памяти.
Реализация ken (лучшая последовательная) великолепна, поскольку Numba генерирует ассемблерный код, используя инструкции SIMD, и его использование памяти очень незначительно (как и собственные коды). Он не так хорош, как оптимизированный нативный код, но довольно близок к этому. Я думаю, что основным недостатком является значительное время компиляции (800 мс платится только один раз при самом первом вызове).
Кроме того, можно отметить, что Numpy намного медленнее того, что можно реализовать в собственном коде (например, в C/C++). Разрыв становится еще больше, когда его собственный код оптимизирован для использования модулей SIMD, доступных на всех основных процессорах. Собственный код требует <1 КБ дополнительной памяти. У кода Numpy нет шансов приблизиться к такой производительности. Оптимизированный SIMD-совместимый собственный код примерно в 34 раза быстрее, чем предлагаемая реализация Numpy, и в 200 раз быстрее, чем исходный код, а также использует еще меньше памяти!
В целом, мы видим, что существует огромный разрыв между собственными/обработанными кодами и кодами, использующими только Numpy, как с точки зрения использования памяти, так и с точки зрения скорости.
Несколько решений, которые занимают гораздо меньше памяти, чем исходное, и работают быстрее. Память можно еще больше уменьшить, разбивая на части, как это сделал Жером. Умеренное остроумие L = 1000
:
Memory:
59.76 bytes/element original
3.06 bytes/element mine
9.00 bytes/element mine4
8.96 bytes/element mine5
Speed:
9.8 ± 0.3 ms mine
11.0 ± 0.5 ms mine5
22.2 ± 1.7 ms mine4
99.1 ± 4.9 ms original
Python: 3.12.2 (main, Jun 12 2024, 09:13:57) [GCC 14.1.1 20240522]
NumPy: 1.26.4
Я получаю значения в четырех направлениях (вверх/влево/вправо/вниз). В mine
я сравниваю все шесть пар, сохраняю результаты сравнения в виде 6-битных чисел, а затем смотрю, какое направление означает каждое 6-битное значение. В mine4
и mine5
я отслеживаю максимумы.
import numpy as np
def original(a):
swv = np.lib.stride_tricks.sliding_window_view(a, (3, 3))
return swv.reshape(L-2, L-2, 9)[:,:,1::2].argmax(axis = 2).astype(np.uint8)
def mine(a):
L = a[1:-1, :-2]
R = a[1:-1, 2:]
U = a[:-2, 1:-1]
D = a[2:, 1:-1]
cmp = (U < L).view(np.uint8)
for (x, y) in (U, R), (U, D), (L, R), (L, D), (R, D):
cmp = (cmp << 1) | (x < y).view(np.uint8)
table = np.array([
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 3,
0, 0, 0, 0, 2, 0, 2, 0, 0, 0, 0, 0, 0, 0, 2, 3,
1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 3, 0, 0, 0, 0,
1, 0, 0, 0, 2, 0, 0, 0, 1, 1, 0, 3, 2, 0, 2, 3
], dtype=np.uint8)
return table[cmp]
# How I got my magic table
# table = np.zeros(64)
# table[mine(a)] = original(a)
# print(repr(table.astype(np.uint8)))
def mine4(a):
U = a[:-2, 1:-1]
L = a[1:-1, :-2]
R = a[1:-1, 2:]
D = a[2:, 1:-1]
max = np.maximum(U, L)
dir = (L > U).view(np.uint8)
dir[R > max] = 2
max = np.maximum(max, R)
dir[D > max] = 3
return dir
def mine5(a):
U = a[:-2, 1:-1]
L = a[1:-1, :-2]
R = a[1:-1, 2:]
D = a[2:, 1:-1]
return np.where(
np.maximum(R, D) > np.maximum(U, L),
(D > R).view(np.uint8) + 2,
(L > U).view(np.uint8)
)
funcs = [original, mine, mine4, mine5]
from timeit import timeit
from statistics import mean, stdev
import sys
import random
import tracemalloc as tm
L = 1000
a = np.random.random((L, L)).astype(np.float32)
# Correctness
print('Correctness:')
expect = original(a)
for f in funcs:
print((f(a) == expect).all(), f.__name__)
# Memory
print('\nMemory:')
for f in funcs * 2:
tm.start()
f(a)
print(f'{tm.get_traced_memory()[1] / L**2 :5.2f} bytes/element ', f.__name__)
tm.stop()
# Speed
times = {f: [] for f in funcs}
def stats(f):
ts = [t * 1e3 for t in sorted(times[f])[:5]]
return f'{mean(ts):5.1f} ± {stdev(ts):3.1f} ms '
for _ in range(25):
random.shuffle(funcs)
for f in funcs:
t = timeit(lambda: f(a), number=1) / 1
times[f].append(t)
for f in sorted(funcs, key=stats):
print(stats(f), f.__name__)
print('\nPython:', sys.version)
print('NumPy: ', np.__version__)
Поскольку использование sliding_window_view
неэффективно для вашего случая, я предложу альтернативу с использованием Numba.
Во-первых, чтобы упростить реализацию, определите следующую альтернативу argmax.
from numba import njit
@njit
def argmax(*values):
"""argmax alternative that can take an arbitrary number of arguments.
Usage: argmax(0, 1, 3, 2) # 2
"""
max_arg = 0
max_value = values[0]
for i in range(1, len(values)):
value = values[i]
if value > max_value:
max_value = value
max_arg = i
return max_arg
Это стандартная функция argmax, за исключением того, что она принимает несколько скалярных аргументов вместо одного массива numpy.
Используя эту альтернативу argmax, вашу операцию можно легко реализовать повторно.
@njit(cache=True)
def neighbor_argmax(a):
height, width = a.shape[0] - 2, a.shape[1] - 2
out = np.empty((height, width), dtype=np.uint8)
for y in range(height):
for x in range(width):
# window: a[y:y + 3, x:x + 3]
# center: a[y + 1, x + 1]
out[y, x] = argmax(
a[y, x + 1], # up
a[y + 1, x], # left
a[y + 1, x + 2], # right
a[y + 2, x + 1], # down
)
return out
Для работы этой функции требуется всего несколько переменных, исключая входной и выходной буферы. Поэтому нам не нужно беспокоиться об объеме памяти.
Альтернативно вы можете использовать трафарет, утилиту для раздвижных окон для Numba.
С помощью stencil
вам нужно только определить ядро. Нумба позаботится обо всем остальном.
from numba import njit, stencil
@stencil
def kernel(window):
# window: window[-1:2, -1:2]
# center: window[0, 0]
return np.uint8( # Don't forget to cast to np.uint8.
argmax(
window[-1, 0], # up
window[0, -1], # left
window[0, 1], # right
window[1, 0], # down
)
)
@njit(cache=True)
def neighbor_argmax_stencil(a):
return kernel(a)[1:-1, 1:-1] # Slicing is not mandatory.
Если хотите, его также можно встроить.
@njit(cache=True)
def neighbor_argmax_stencil_inlined(a):
f = stencil(lambda w: np.uint8(argmax(w[-1, 0], w[0, -1], w[0, 1], w[1, 0])))
return f(a)[1:-1, 1:-1] # Slicing is not mandatory.
Однако stencil
очень ограничен по функциональности и не может полностью заменить sliding_window_view
.
Единственное отличие состоит в том, что нет возможности пропускать края.
Он всегда дополняется постоянным значением (по умолчанию 0).
То есть, если вы поставите матрицу (L, L)
, вы получите результат (L, L)
, а не (L-2, L-2)
.
Вот почему я вырезаю выходные данные приведенного выше кода, чтобы они соответствовали вашей реализации. Однако это может быть нежелательным поведением, поскольку оно нарушает непрерывность памяти. Вы можете копировать после нарезки, но имейте в виду, что это увеличит пиковое использование памяти.
Кроме того, следует отметить, что эти функции также можно легко адаптировать для многопоточности. Подробную информацию см. в эталонном коде ниже.
Вот эталон.
import math
import timeit
import numpy as np
from numba import njit, prange, stencil
from numpy.lib.stride_tricks import sliding_window_view
def baseline(a):
L = a.shape[0]
swv = sliding_window_view(a, (3, 3)) # (L-2) x (L-2) x 3 x 3
directions = swv.reshape(L - 2, L - 2, 9)[:, :, 1::2].argmax(axis=2).astype(np.uint8)
return directions
@njit
def argmax(*values):
"""argmax alternative that can accept an arbitrary number of arguments.
Usage: argmax(0, 1, 3, 2) # 2
"""
max_arg = 0
max_value = values[0]
for i in range(1, len(values)):
value = values[i]
if value > max_value:
max_value = value
max_arg = i
return max_arg
@njit(cache=True)
def neighbor_argmax(a):
height, width = a.shape[0] - 2, a.shape[1] - 2
out = np.empty((height, width), dtype=np.uint8)
for y in range(height):
for x in range(width):
# window: a[y:y + 3, x:x + 3]
# center: a[y + 1, x + 1]
out[y, x] = argmax(
a[y, x + 1], # up
a[y + 1, x], # left
a[y + 1, x + 2], # right
a[y + 2, x + 1], # down
)
return out
@njit(cache=True, parallel=True) # Add parallel=True.
def neighbor_argmax_mt(a):
height, width = a.shape[0] - 2, a.shape[1] - 2
out = np.empty((height, width), dtype=np.uint8)
for y in prange(height): # Change this to prange.
for x in range(width):
# window: a[y:y + 3, x:x + 3]
# center: a[y + 1, x + 1]
out[y, x] = argmax(
a[y, x + 1], # up
a[y + 1, x], # left
a[y + 1, x + 2], # right
a[y + 2, x + 1], # down
)
return out
@stencil
def kernel(window):
# window: window[-1:2, -1:2]
# center: window[0, 0]
return np.uint8( # Don't forget to cast to np.uint8.
argmax(
window[-1, 0], # up
window[0, -1], # left
window[0, 1], # right
window[1, 0], # down
)
)
@njit(cache=True)
def neighbor_argmax_stencil(a):
return kernel(a)[1:-1, 1:-1] # Slicing is not mandatory.
@njit(cache=True)
def neighbor_argmax_stencil_with_copy(a):
return kernel(a)[1:-1, 1:-1].copy() # Slicing is not mandatory.
@njit(cache=True, parallel=True)
def neighbor_argmax_stencil_mt(a):
return kernel(a)[1:-1, 1:-1] # Slicing is not mandatory.
@njit(cache=True)
def neighbor_argmax_stencil_inlined(a):
f = stencil(lambda w: np.uint8(argmax(w[-1, 0], w[0, -1], w[0, 1], w[1, 0])))
return f(a)[1:-1, 1:-1] # Slicing is not mandatory.
def benchmark():
size = 2000 # Total nbytes (in MB) for a.
n = math.ceil(math.sqrt(size * (10 ** 6) / 4))
rng = np.random.default_rng(0)
a = rng.random(size=(n, n), dtype=np.float32)
print(f"{a.shape=}, {a.nbytes=:,}")
expected = baseline(a)
# expected = neighbor_argmax_mt(a)
assert expected.shape == (n - 2, n - 2) and expected.dtype == np.uint8
candidates = [
baseline,
neighbor_argmax,
neighbor_argmax_mt,
neighbor_argmax_stencil,
neighbor_argmax_stencil_mt,
neighbor_argmax_stencil_with_copy,
neighbor_argmax_stencil_inlined,
]
name_len = max(len(f.__name__) for f in candidates)
for f in candidates:
assert np.array_equal(expected, f(a)), f.__name__
t = timeit.repeat(lambda: f(a), repeat=3, number=1)
print(f"{f.__name__:{name_len}} : {min(t)}")
if __name__ == "__main__":
benchmark()
Результат:
a.shape=(22361, 22361), a.nbytes=2,000,057,284
baseline : 24.971996600041166
neighbor_argmax : 0.1917789001017809
neighbor_argmax_mt : 0.11929619999136776
neighbor_argmax_stencil : 0.2940085999434814
neighbor_argmax_stencil_mt : 0.17756330000702292
neighbor_argmax_stencil_with_copy : 0.46573049994185567
neighbor_argmax_stencil_inlined : 0.29338629997801036
Я думаю, этих результатов достаточно, чтобы вы подумали о том, чтобы попробовать Numba :)