суммировать массив 3x3 в данной точке с другой матрицей, поддерживающей границы

1

предположим, что у меня есть этот 2d-массив A:

[[0,0,0,0],
 [0,0,0,0],
 [0,0,0,0],
 [0,0,0,4]]

и я хочу суммировать B:

[[1,2,3]
 [4,5,6]
 [7,8,9]]

с центром на A [0] [0], поэтому результатом будет:

array_sum(A,B,0,0) =
[[5,6,0,4],
 [8,9,0,0],
 [0,0,0,0],
 [2,0,0,5]]

Я думал, что должен сделать функцию, которая сравнивает, если она находится на границе, а затем корректирует индекс для этого:

def array_sum(A,B,i,f):
   ...
   if i == 0 and j == 0:
      A[-1][-1] = A[-1][-1]+B[0][0]
      ...

   else:
      A[i-1][j-1] = A[i][j]+B[0][0]
      A[i][j] = A[i][j]+B[1][1]
      A[i+1][j+1] = A[i][j]+B[2][2]
      ...

но я не знаю, есть ли лучший способ сделать это, я читал о вещании или, возможно, использовал convolute для этого, но я не уверен, есть ли лучший способ сделать это.

  • 0
    Приспосабливание к границе не плохая идея, но есть лучшие способы сделать это, чем иметь 9 операторов if.
  • 0
    размеры B всегда будут нечетными?
Показать ещё 4 комментария
Теги:
numpy

2 ответа

1

Предполагая, что B.shape - это все нечетные числа, вы можете использовать np.indices, манипулировать ими, чтобы указать, где хотите, и использовать np.add.at

def array_sum(A, B, loc = (0, 0)):
    A_ = A.copy()
    ix = np.indices(B.shape)
    new_loc = np.array(loc) - np.array(B.shape) // 2
    new_ix = np.mod(ix + new_loc[:, None, None], 
                    np.array(A.shape)[:, None, None])
    np.add.at(A_, tuple(new_ix), B)
    return A_

Тестирование:

array_sum(A, B)
Out:
array([[ 5.,  6.,  0.,  4.],
       [ 8.,  9.,  0.,  7.],
       [ 0.,  0.,  0.,  0.],
       [ 2.,  3.,  0.,  5.]])
0

Как правило, индексация фрагментов быстрее (~ 2x), чем причудливая индексация. Это кажется верным даже для небольшого примера в OP. Нижняя сторона: код немного сложнее.

import numpy as np
from numpy import s_ as _
from itertools import product, starmap

def wrapsl1d(N, n, c):
    # check in 1D whether a patch of size n centered at c in a vector
    # of length N fits or has to be wrapped around
    # return appropriate slice objects for both vector and patch
    assert n <= N
    l = (c - n//2) % N
    h = l + n
    # return list of pairs (index into A, index into patch)
    # 2 pairs if we wrap around, otherwise 1 pair
    return [_[l:h, :]] if h <= N else [_[l:, :N-l], _[:h-N, n+N-h:]]

def use_slices(A, patch, center=(0, 0)):
    slAptch = product(*map(wrapsl1d, A.shape, patch.shape, center))
    # the product now has elements [(idx0A, idx0ptch), (idx1A, idx1ptch)]
    # transpose them:
    slAptch = starmap(zip, slAptch)
    out = A.copy()
    for sa, sp in slAptch:
        out[sa] += patch[sp]
    return out

Ещё вопросы

Сообщество Overcoder
Наверх
Меню