Удаление плохих пикселей с помощью массива dask

1

У меня очень большой 4-мерный набор данных, где два последних измерения - это изображения, сделанные на детекторе изображения. Некоторые пиксели на этом детекторе не работают, и эти мертвые пиксели дают значение 0. Я хотел бы установить значение этих пикселей для медианы своих соседей в последующей обработке. Наборы данных варьируются от 8 ГБ до размера TB, поэтому я хотел бы использовать массив dask, так как я могу связать удаление мертвого пикселя вместе с другими этапами обработки.

Найти мертвые пиксели легко, но я не уверен, как лучше всего получить медиану соседей.

Минимальный пример:

import numpy as np
import dask.array as da

data = np.random.randint(10, 50, size=(10, 10, 20, 20))
data[:, :, 2, 7] = 0
data[:, :, 9, 3] = 0
dask_array = da.from_array(data, chunks=(5, 5, 5, 5))
dask_array_mean = dask_array.mean(axis=(0, 1))
dead_pixels = dask_array_mean == 0

# Some kind of processing

dask_array_without_dead_pixels = dask_array + dead_pixel_values_array

Поэтому мой вопрос: как мне получить dead_pixel_values_array? Или какой-нибудь другой умный способ удаления мертвых пикселей?

Теги:
dask

2 ответа

2

Вот реализация региональной медианы для массивов dask: https://dask-ndfilters.readthedocs.io/en/latest/dask_ndfilters.html#dask_ndfilters.median_filter

Если вам нужно что-то более общее, или вы не хотите устанавливать dask_ndfilters, вы должны прочитать на map_overlap который позволяет вам получать доступ к данным вокруг каждого фрагмента, который поступает из соседних блоков, и таким образом учитывает ваши вычисления.

0

Мне удалось это сделать, сначала переназначив размеры изображения (два последних) на один фрагмент, а затем используя map_blocks. map_overlay работал для небольших наборов данных, но для более крупных приложений использование памяти было намного больше, чем доступная память.

import numpy as np
import dask.array as da
import matplotlib.pyplot as plt

def remove_dead_pixels(data, dead_pixels):
    dif0 = np.roll(data, shift=1, axis=-2) * dead_pixels
    dif1 = np.roll(data, shift=-1, axis=-2) * dead_pixels
    dif2 = np.roll(data, shift=1, axis=-1) * dead_pixels
    dif3 = np.roll(data, shift=-1, axis=-1) * dead_pixels
    output_data = np.median(np.stack([dif0, dif1, dif2, dif3], axis=-1),
                            axis=-1, overwrite_input=True, keepdims=False)
    return output_data

# Making artificial data
data = np.random.randint(10, 50, size=(10, 10, 20, 20))
data[:, :, 2, 7] = 0
data[:, :, 9, 3] = 0
dask_array = da.from_array(data, chunks=(5, 5, 5, 5))
dask_array = dask_array.rechunk((5, 5, 20, 20))

# Finding dead pixels
dask_array_mean = dask_array.mean(axis=(0, 1))
dead_pixels = dask_array_mean == 0

# Getting replacement values
dead_pixel_values_array = da.map_blocks(
        remove_dead_pixels, dask_array, dead_pixels, dtype=dask_array.dtype,
        chunks=dask_array.chunks)
dask_array_without_dead_pixels = dask_array + dead_pixel_values_array

# Plotting result
fig, axarr = plt.subplots(1, 2, figsize=(10, 5))
axarr[0].imshow(dask_array.sum(axis=(0, 1)).compute())
axarr[1].imshow(dask_array_without_dead_pixels.sum(axis=(0, 1)).compute())
fig.tight_layout()
fig.savefig("image.jpg")

Изображение 174551

Ещё вопросы

Сообщество Overcoder
Наверх
Меню