У меня очень большой 4-мерный набор данных, где два последних измерения - это изображения, сделанные на детекторе изображения. Некоторые пиксели на этом детекторе не работают, и эти мертвые пиксели дают значение 0. Я хотел бы установить значение этих пикселей для медианы своих соседей в последующей обработке. Наборы данных варьируются от 8 ГБ до размера TB, поэтому я хотел бы использовать массив dask, так как я могу связать удаление мертвого пикселя вместе с другими этапами обработки.
Найти мертвые пиксели легко, но я не уверен, как лучше всего получить медиану соседей.
Минимальный пример:
import numpy as np
import dask.array as da
data = np.random.randint(10, 50, size=(10, 10, 20, 20))
data[:, :, 2, 7] = 0
data[:, :, 9, 3] = 0
dask_array = da.from_array(data, chunks=(5, 5, 5, 5))
dask_array_mean = dask_array.mean(axis=(0, 1))
dead_pixels = dask_array_mean == 0
# Some kind of processing
dask_array_without_dead_pixels = dask_array + dead_pixel_values_array
Поэтому мой вопрос: как мне получить dead_pixel_values_array
? Или какой-нибудь другой умный способ удаления мертвых пикселей?
Вот реализация региональной медианы для массивов dask: https://dask-ndfilters.readthedocs.io/en/latest/dask_ndfilters.html#dask_ndfilters.median_filter
Если вам нужно что-то более общее, или вы не хотите устанавливать dask_ndfilters
, вы должны прочитать на map_overlap
который позволяет вам получать доступ к данным вокруг каждого фрагмента, который поступает из соседних блоков, и таким образом учитывает ваши вычисления.
Мне удалось это сделать, сначала переназначив размеры изображения (два последних) на один фрагмент, а затем используя map_blocks
. map_overlay
работал для небольших наборов данных, но для более крупных приложений использование памяти было намного больше, чем доступная память.
import numpy as np
import dask.array as da
import matplotlib.pyplot as plt
def remove_dead_pixels(data, dead_pixels):
dif0 = np.roll(data, shift=1, axis=-2) * dead_pixels
dif1 = np.roll(data, shift=-1, axis=-2) * dead_pixels
dif2 = np.roll(data, shift=1, axis=-1) * dead_pixels
dif3 = np.roll(data, shift=-1, axis=-1) * dead_pixels
output_data = np.median(np.stack([dif0, dif1, dif2, dif3], axis=-1),
axis=-1, overwrite_input=True, keepdims=False)
return output_data
# Making artificial data
data = np.random.randint(10, 50, size=(10, 10, 20, 20))
data[:, :, 2, 7] = 0
data[:, :, 9, 3] = 0
dask_array = da.from_array(data, chunks=(5, 5, 5, 5))
dask_array = dask_array.rechunk((5, 5, 20, 20))
# Finding dead pixels
dask_array_mean = dask_array.mean(axis=(0, 1))
dead_pixels = dask_array_mean == 0
# Getting replacement values
dead_pixel_values_array = da.map_blocks(
remove_dead_pixels, dask_array, dead_pixels, dtype=dask_array.dtype,
chunks=dask_array.chunks)
dask_array_without_dead_pixels = dask_array + dead_pixel_values_array
# Plotting result
fig, axarr = plt.subplots(1, 2, figsize=(10, 5))
axarr[0].imshow(dask_array.sum(axis=(0, 1)).compute())
axarr[1].imshow(dask_array_without_dead_pixels.sum(axis=(0, 1)).compute())
fig.tight_layout()
fig.savefig("image.jpg")