предположим, что у меня есть этот 2d-массив A:
[[0,0,0,0],
[0,0,0,0],
[0,0,0,0],
[0,0,0,4]]
и я хочу суммировать B:
[[1,2,3]
[4,5,6]
[7,8,9]]
с центром на A [0] [0], поэтому результатом будет:
array_sum(A,B,0,0) =
[[5,6,0,4],
[8,9,0,0],
[0,0,0,0],
[2,0,0,5]]
Я думал, что должен сделать функцию, которая сравнивает, если она находится на границе, а затем корректирует индекс для этого:
def array_sum(A,B,i,f):
...
if i == 0 and j == 0:
A[-1][-1] = A[-1][-1]+B[0][0]
...
else:
A[i-1][j-1] = A[i][j]+B[0][0]
A[i][j] = A[i][j]+B[1][1]
A[i+1][j+1] = A[i][j]+B[2][2]
...
но я не знаю, есть ли лучший способ сделать это, я читал о вещании или, возможно, использовал convolute для этого, но я не уверен, есть ли лучший способ сделать это.
Предполагая, что B.shape
- это все нечетные числа, вы можете использовать np.indices
, манипулировать ими, чтобы указать, где хотите, и использовать np.add.at
def array_sum(A, B, loc = (0, 0)):
A_ = A.copy()
ix = np.indices(B.shape)
new_loc = np.array(loc) - np.array(B.shape) // 2
new_ix = np.mod(ix + new_loc[:, None, None],
np.array(A.shape)[:, None, None])
np.add.at(A_, tuple(new_ix), B)
return A_
Тестирование:
array_sum(A, B)
Out:
array([[ 5., 6., 0., 4.],
[ 8., 9., 0., 7.],
[ 0., 0., 0., 0.],
[ 2., 3., 0., 5.]])
Как правило, индексация фрагментов быстрее (~ 2x), чем причудливая индексация. Это кажется верным даже для небольшого примера в OP. Нижняя сторона: код немного сложнее.
import numpy as np
from numpy import s_ as _
from itertools import product, starmap
def wrapsl1d(N, n, c):
# check in 1D whether a patch of size n centered at c in a vector
# of length N fits or has to be wrapped around
# return appropriate slice objects for both vector and patch
assert n <= N
l = (c - n//2) % N
h = l + n
# return list of pairs (index into A, index into patch)
# 2 pairs if we wrap around, otherwise 1 pair
return [_[l:h, :]] if h <= N else [_[l:, :N-l], _[:h-N, n+N-h:]]
def use_slices(A, patch, center=(0, 0)):
slAptch = product(*map(wrapsl1d, A.shape, patch.shape, center))
# the product now has elements [(idx0A, idx0ptch), (idx1A, idx1ptch)]
# transpose them:
slAptch = starmap(zip, slAptch)
out = A.copy()
for sa, sp in slAptch:
out[sa] += patch[sp]
return out
B
всегда будут нечетными?