ускорение класса с помощью массива с использованием Cython

1

У меня есть следующие коды:

class _Particles:
    def __init__(self, num_particle, dim, fun, lower_bound, upper_bound):
        self.lower_bound = lower_bound   # np.array of shape (dim,)
        self.upper_bound = upper_bound   # np.array of shape (dim,)
        self.num_particle = num_particle   # a scalar
        self.dim = dim   # dimension, a scalar
        self.fun = fun   # a function

        self.pos = np.empty((num_particle,dim))
        self.val = np.empty(num_particle)
        self.randomize()


    def randomize(self):
        self.pos = np.random.rand(self.num_particle, self.dim)*(self.upper_bound\
                -self.lower_bound)+self.lower_bound
        self.val = self.fun(np.transpose(self.pos))
        self.best_idx = np.argmin(self.val)
        self.best_val = self.val[self.best_idx]
        self.best_pos = self.pos[self.best_idx]


    def move(self, displacement, idx='all', check_bound=True):
        if idx is 'all':
            self.pos += displacement
        elif isinstance(idx,(tuple,list,np.ndarray)):
            self.pos[idx] += displacement
        else:
            raise TypeError('Check the type of idx!',type(idx))

        self.pos = np.maximum(self.pos, self.lower_bound[np.newaxis,:])
        self.pos = np.minimum(self.pos, self.upper_bound[np.newaxis,:])
        self.val = self.fun(np.transpose(self.pos))
        self.best_idx = np.argmin(self.val)
        self.best_val = self.val[self.best_idx]
        self.best_pos = self.pos[self.best_idx]

Я хочу проверить, могу ли я ускорить вышеуказанный код, и я думаю об использовании cython, но я не уверен, возможно ли это, поскольку он в основном использует массив numpy, а большинство исполнений выполняются с помощью векторизации. Я пробую что-то вроде:

# the .pyx file that will be compiled
cdef class _Particles(object):
    cdef int num_particle
    cdef int dim
    cdef fun
    cdef np.ndarray lower_bound
    cdef np.ndarray upper_bound
    cdef np.ndarray pos
    cdef np.ndarray val
    cdef int best_idx
    cdef double best_val
    cdef np.ndarray[np.float64_t, ndim=1] best_pos

    def __init__(self, int num_particle, int dim, fun,
                 np.ndarray lower_bound, np.ndarray upper_bound):
        self.num_particle = num_particle
        self.dim = dim
        self.fun = fun
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound

        self.pos = np.empty((num_particle,dim))
        self.val = np.empty(num_particle)
        self.randomize()

    def randomize(self):
        self.pos = npr.rand(self.num_particle,self.dim)*(self.upper_bound\
                -self.lower_bound)+self.lower_bound

        self.val = self.fun(np.transpose(self.pos))
        self.best_idx = np.argmin(self.val)
        self.best_val = self.val[self.best_idx]
        self.best_pos = self.pos[self.best_idx]

Это быстрее, но только немного, что и ожидалось, поскольку это по-прежнему в основном код python. Итак, есть ли способы ускорить вышеуказанный код с помощью cython (или указать мне на некоторые другие полностью методы)? В частности, как ускорить такие коды, как self.fun(self.pos), np.argmin(self.val)?

Благодарю.

  • 0
    Я бы посоветовал вам определить любые пользовательские имена, которые есть в вашем коде, например, npr (я думаю, это numpy.random ).
Теги:
numpy
vectorization
cython

1 ответ

3
Лучший ответ

На самом деле, я не боюсь оптимизировать приведенный выше код. Чтобы argmin я предлагаю вам получить (или иначе скомпилировать себя) NumPy с поддержкой многопоточности (или вы можете повторно реализовать несколько многопоточных argmin).

Что касается Cython, вы получаете реальную сделку, когда начинаете использовать типы C, но это то, что я не видел бы большого улучшения с кодом, который вы опубликовали. Это, в основном, клей-код, там не было никакого хруста.

Я бы ожидал, что в функции fun произойдет свертка чисел, и это, вероятно, единственное место, где фактическая ручная оптимизация может иметь значение, если это не так просто для векторизации (читайте: есть или for ручного ручного управления). Затем я начну с numba, что намного упрощает ускорение вашего кода, если оно работает. Если это не так, вероятно, стоит начать Cython.

Ещё вопросы

Сообщество Overcoder
Наверх
Меню