В MATLAB легко найти индексы значений, которые удовлетворяют конкретному условию:
>> a = [1,2,3,1,2,3,1,2,3];
>> find(a > 2) % find the indecies where this condition is true
[3, 6, 9] % (MATLAB uses 1-based indexing)
>> a(find(a > 2)) % get the values at those locations
[3, 3, 3]
Каким будет лучший способ сделать это в Python?
До сих пор я придумал следующее. Чтобы просто получить значения:
>>> a = [1,2,3,1,2,3,1,2,3]
>>> [val for val in a if val > 2]
[3, 3, 3]
Но если я хочу, чтобы индекс каждого из этих значений немного сложнее:
>>> a = [1,2,3,1,2,3,1,2,3]
>>> inds = [i for (i, val) in enumerate(a) if val > 2]
>>> inds
[2, 5, 8]
>>> [val for (i, val) in enumerate(a) if i in inds]
[3, 3, 3]
Есть ли лучший способ сделать это в Python, особенно для произвольных условий (не только "val > 2" )?
Я нашел функции, эквивалентные MATLAB 'find' в NumPy, но в настоящее время у меня нет доступа к этим библиотекам.
Вы можете сделать функцию, которая принимает вызываемый параметр, который будет использоваться в части условия вашего понимания списка. Затем вы можете использовать lambda или другой объект функции, чтобы передать произвольное условие:
def indices(a, func):
return [i for (i, val) in enumerate(a) if func(val)]
a = [1, 2, 3, 1, 2, 3, 1, 2, 3]
inds = indices(a, lambda x: x > 2)
>>> inds
[2, 5, 8]
Это немного ближе к вашему примеру Matlab, без необходимости загружать все numpy.
inds = [i for (i, val) in enumerate(a) if val > 2]
что является решением в одну строку.
в numpy у вас есть where
:
>> import numpy as np
>> x = np.random.randint(0, 20, 10)
>> x
array([14, 13, 1, 15, 8, 0, 17, 11, 19, 13])
>> np.where(x > 10)
(array([0, 1, 3, 6, 7, 8, 9], dtype=int64),)
x[x>3]
вместо np.where(x>3)
) (Не то, чтобы с where
то что-то не так! Прямая индексация может быть просто более знакомой формой для людей, знакомых с Matlab.)
Или используйте ненулевую функцию numpy:
import numpy as np
a = np.array([1,2,3,4,5])
inds = np.nonzero(a>2)
a[inds]
array([3, 4, 5])
Почему бы просто не использовать это:
[i for i in range(len(a)) if a[i] > 2]
или для произвольных условий, определите функцию f
для вашего условия и выполните:
[i for i in range(len(a)) if f(a[i])]
Я пытался найти быстрый способ сделать эту точную вещь, и вот что я наткнулся (использует numpy для быстрого сравнения векторов):
a_bool = numpy.array(a) > 2
inds = [i for (i, val) in enumerate(a_bool) if val]
Оказывается, это намного быстрее, чем:
inds = [i for (i, val) in enumerate(a) if val > 2]
Похоже, что Python быстрее сравнивается, когда выполняется в массиве numpy, и/или быстрее при составлении списков при проверке правды, а не сравнении.
Edit:
Я пересматривал свой код, и я наткнулся на, возможно, меньше интенсивный объем памяти, немного быстрее и супер-сжатый способ сделать это в одной строке:
inds = np.arange( len(a) )[ a < 2 ]
Чтобы получить значения с произвольными условиями, вы можете использовать filter()
с помощью лямбда-функции:
>>> a = [1,2,3,1,2,3,1,2,3]
>>> filter(lambda x: x > 2, a)
[3, 3, 3]
Одним из возможных способов получения индексов будет использование enumerate()
для построения кортежа с индексом и значениями, а затем фильтрация:
>>> a = [1,2,3,1,2,3,1,2,3]
>>> aind = tuple(enumerate(a))
>>> print aind
((0, 1), (1, 2), (2, 3), (3, 1), (4, 2), (5, 3), (6, 1), (7, 2), (8, 3))
>>> filter(lambda x: x[1] > 2, aind)
((2, 3), (5, 3), (8, 3))
filter
, но использование списочных представлений предпочтительнее и более оптимизировано.
Думаю, я нашел одну быструю и простую замену. BTW Я чувствовал, что функция np.where() не очень удовлетворительна, в каком-то смысле она содержит раздражающую строку нулевого элемента.
import matplotlib.mlab as mlab
a = np.random.randn(1,5)
print a
>> [[ 1.36406736 1.45217257 -0.06896245 0.98429727 -0.59281957]]
idx = mlab.find(a<0)
print idx
type(idx)
>> [2 4]
>> np.ndarray
Бест, Da
Подпрограмма numpy
, обычно используемая для этого приложения, numpy.where()
; хотя я считаю, что он работает так же, как numpy.nonzero()
.
import numpy
a = numpy.array([1,2,3,4,5])
inds = numpy.where(a>2)
Чтобы получить значения, вы можете либо сохранить индексы и срез с их помощью:
a[inds]
или вы можете передать массив как необязательный параметр:
numpy.where(a>2, a)
или несколько массивов:
b = numpy.array([11,22,33,44,55])
numpy.where(a>2, a, b)
Код поиска Matlab имеет два аргумента. Код John учитывает первый аргумент, но не второй. Например, если вы хотите знать, где в индексе выполняется условие: функция Mtlab будет:
find(x>2,1)
Используя код Джона, все, что вам нужно сделать, это добавить [x] в конец функции индексов, где x - номер индекса, который вы ищете.
def indices(a, func):
return [i for (i, val) in enumerate(a) if func(val)]
a = [1, 2, 3, 1, 2, 3, 1, 2, 3]
inds = indices(a, lambda x: x > 2)[0] #[0] being the 2nd matlab argument
который возвращает → > 2, первый индекс должен превышать 2.
[a[i] for i in inds]
, что немного проще.