Ускорить код Cython

1

Я написал код python, который управляет большим количеством данных, и, следовательно, это занимает много времени. Итак, я узнал, что Китон и я начали менять свой код.

В основном, все, что я делал, - это изменить объявления функций (имя типа cdef (аргументы с типом переменной)), объявить переменные cdef с его типом и объявить классы cdef. Я пишу все .pyx с eclipse, и я компилирую команду python setup.py build_ext --inplace и запускаю ее с помощью eclipse.

Моя проблема в том, что сравнение python с скоростью cython, нет никакой разницы.

Я запускаю команду cython -a <file> чтобы сгенерировать html файл, и есть много желтых строк.

Я не знаю, что я делаю что-то неправильно, я должен включить что-то еще, и я не знаю, как удалить эти желтые строки.

Я просто вставляю несколько строк кода, что часть, которую я хотел бы ускорить, и потому, что код очень длинный.


main.pyx

'''there are a lot of ndarray objects stored in a file and in this step I get each of them until there are no more items '''
cdef ReadWavePoints (WavePointManagement wavePointManagement, ColumnManagement columnManagement):
        cdef int runReadWavePoints

    wavePointManagement.OpenWavePointFileLoad(wavePointsFile)
    runReadWavePoints = 1

    while runReadWavePoints == 1:
        try:
            wavePointManagement.LoadWavePointFile()
            wavePointManagement.RoundCoordinates()
            wavePointManagement.SortWavePointList()
            GroupColumnsVoxels(wavePointManagement.GetWavePointList(), columnManagement)
        except:
            wavePointManagement.CloseWavePointFile()
            columnManagement.CloseWriteColumnFile()
            break

'''I check which points are in the same XYZ (voxel) and in the same XY (column)'''

cdef GroupColumnsVoxels (object wavePointList, ColumnManagement columnManagement):
    cdef int indexWavePointRef, indexWavePoint
    cdef int saved
    cdef double voxelValue
    cdef int sizeWavePointList

    sizeWavePointList = len(wavePointList)

    indexWavePointRef = 0

    while indexWavePointRef < sizeWavePointList - 1:
        saved = 0
        voxelValue = (wavePointList[indexWavePointRef]).GetValue()
        for indexWavePoint in xrange(indexWavePointRef + 1, len(wavePointList)):
            if (wavePointList[indexWavePointRef]).GetX() == (wavePointList[indexWavePoint]).GetX() and (wavePointList[indexWavePointRef]).GetY() == (wavePointList[indexWavePoint]).GetY():
                if (wavePointList[indexWavePointRef]).GetZ() == (wavePointList[indexWavePoint]).GetZ():
                    if voxelValue < (wavePointList[indexWavePoint]).GetValue():
                        voxelValue = (wavePointList[indexWavePoint]).GetValue()
                else:
                    saved = 1
                    CheckVoxel((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY(), (wavePointList[indexWavePointRef]).GetZ(), voxelValue)
                    indexWavePointRef = indexWavePoint
                    if indexWavePointRef == sizeWavePointList - 1:
                        CheckVoxel((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY(), (wavePointList[indexWavePointRef]).GetZ(), (wavePointList[indexWavePointRef]).GetValue())
                    break
            else:
                saved = 1
                CheckVoxel((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY(), (wavePointList[indexWavePointRef]).GetZ(), voxelValue)
                columnObject = columnInstance.Column((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY())
                columnManagement.AddColumn(columnObject)
                MaximumHeightColumn((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY(), (wavePointList[indexWavePointRef]).GetZ()) 
                indexWavePointRef = indexWavePoint
                break
        if saved == 0:
            CheckVoxel((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY(), (wavePointList[indexWavePointRef]).GetZ(), voxelValue)
            indexWavePointRef = indexWavePoint
    columnObject = columnInstance.Column((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY())
    columnManagement.AddColumn(columnObject)
    MaximumHeightColumn((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY(), (wavePointList[indexWavePointRef]).GetZ())



'''I check if the data stored in a voxel is lower than the new one; if its the case, I store it'''  

cdef CheckVoxel (double X, double Y, double Z, double newValue):
    cdef object bandVoxel, structvalCheckVoxel, out_str
    cdef tuple valueCheckVoxel

    bandVoxel = datasetVoxels.GetRasterBand(int(math.floor(Z/0.3))+1)
    structvalCheckVoxel = bandVoxel.ReadRaster(int(math.floor((X-Xmin)/0.25)), int(math.floor((Ymax-Y)/0.25)), 1, 1, buf_type=gdal.GDT_Float32)
    valueCheckVoxel = struct.unpack('f', structvalCheckVoxel)

    if newValue > valueCheckVoxel[0]:
        out_str = struct.pack('f', newValue)
        bandVoxel.WriteRaster(int(math.floor((X-Xmin)/0.25)), int(math.floor((Ymax-Y)/0.25)), 1, 1, out_str)

'''I check if this point has the highest Z and I store this information'''    
cdef MaximumHeightColumn(double X, double Y, double newZ):
        cdef object bandMetricMaximumHeightColumn, structvalMaximumHeightColumn, out_strMaximumHeightColumn
    cdef tuple valueMaximumHeightColumn

    bandMetricMaximumHeightColumn = datasetMetrics.GetRasterBand(10)
    structvalMaximumHeightColumn = bandMetricMaximumHeightColumn.ReadRaster(int(math.floor((X-Xmin)/0.25)), int(math.floor((Ymax-Y)/0.25)), 1, 1, buf_type=gdal.GDT_Float32)
    valueMaximumHeightColumn = struct.unpack('f', structvalMaximumHeightColumn)

    if newZ > round(valueMaximumHeightColumn[0], 1):
        out_strMaximumHeightColumn = struct.pack('f', newZ)
        bandMetricMaximumHeightColumn.WriteRaster(int(math.floor((X-Xmin)/0.25)), int(math.floor((Ymax-Y)/0.25)), 1, 1, out_strMaximumHeightColumn)

WavePointManagement.pyx

'''this class serializes, rounds and sorts the points of each ndarray'''
import cPickle as pickle
import numpy as np
cimport numpy as np
import math

cdef class WavePointManagement(object):
    '''
    This class manages all the points extracted from the waveform
    '''
    cdef object fileObject, wavePointList
    __slots__ = ('wavePointList', 'fileObject')

    def __cinit__(self):
        '''
        Constructor
        '''

        self.fileObject = None
        self.wavePointList = np.array([])

    cdef object GetWavePointList(self):
        return self.wavePointList

    cdef void OpenWavePointFileLoad (self, object fileName):
        self.fileObject = file(fileName, 'rb')

    cdef void LoadWavePointFile (self):
        self.wavePointList = None
        self.wavePointList = pickle.load(self.fileObject)

    cdef void SortWavePointList (self):
        self.wavePointList = sorted(self.wavePointList, key=lambda k: (k.x, k.y, k.z))

    cdef void RoundCoordinates (self):
        cdef int indexPointObject, sizeWavePointList

        for pointObject in self.GetWavePointList():
            pointObject.SetX(round(math.floor(pointObject.GetX()/0.25)*0.25, 2))
            pointObject.SetY(round(math.ceil(pointObject.GetY()/0.25)*0.25, 2))
            pointObject.SetZ(round(math.floor(pointObject.GetZ()/0.3)*0.3, 1))

    cdef void CloseWavePointFile(self):
        self.fileObject.close()

setup.py

from distutils.core import setup
from distutils.extension import Extension
from Cython.Distutils import build_ext

import numpy

ext = Extension("main", ["main.pyx"], include_dirs = [numpy.get_include()])

setup (ext_modules=[ext], 
       cmdclass = {'build_ext' : build_ext}
       )

test_cython.py

'''this is the file I run with eclipse after compiling'''
from main import main

main()

Как я могу ускорить этот код?

Теги:
performance
cython

1 ответ

3
Лучший ответ

Ваш код перескакивает назад и вперед между использованием массивов numpy и списков. Таким образом, практически нет никакой разницы между кодом, который будет производить cython.

Следующий код создает список python, а ключевая функция также является чистой функцией python.

self.wavePointList = sorted(self.wavePointList, key=lambda k: (k.x, k.y, k.z))

Вы захотите использовать ndarray.sort (или numpy.sort если вы не хотите сортировать на месте). Для этого вам также нужно будет изменить способ хранения объектов в массиве. То есть вам понадобится использовать структурированный массив. См. numpy.sort для примеров того, как сортировать структурированные массивы - особенно последние два примера на странице.

После того, как у вас есть данные, хранящиеся в массиве numpy, вам нужно сообщить cython о том, как данные хранятся в массиве. Это включает в себя предоставление информации о типе и размерах массива. На этой странице представлена дополнительная информация о том, как эффективно работать с массивами numpy.

Пример показа для создания и сортировки структурированных массивов:

import numpy as np
cimport numpy as np

DTYPE = [('name', 'S10'), ('height', np.float64), ('age', np.int32)]

cdef packed struct Person:
    char name[10]
    np.float64_t height
    np.int32_t age

ctypedef Person DTYPE_t

def create_array():
    values = [('Arthur', 1.8, 41), ('Lancelot', 1.9, 38),
              ('Galahad', 1.7, 38)]
    return np.array(values, dtype=DTYPE)

cpdef sort_by_age_then_height(np.ndarray[DTYPE_t, ndim=1] arr):
    arr.sort(order=['age', 'height'])  

Наконец, вам нужно будет преобразовать ваш код с помощью методов python в использование стандартных методов библиотеки c для дальнейшего ускорения. Ниже приведен пример использования RoundCoordinates. '' cpdef 'означает, что функция также подвергается python с помощью функции-обертки.

cimport cython
cimport numpy as np
from libc.math cimport floor, ceil, round

import numpy as np

DTYPE = [('x', np.float64), ('y', np.float64), ('z', np.float64)]

cdef packed struct Point3D:
    np.float64_t x, y, z

ctypedef Point3D DTYPE_t

# Caution should be used when turning the bounds check off as it can lead to undefined 
# behaviour if you use an invalid index.
@cython.boundscheck(False)
cpdef RoundCoordinates_cy(np.ndarray[DTYPE_t] pointlist):
    cdef int i
    cdef DTYPE_t point
    for i in range(len(pointlist)): # this line is optimised into a c loop
        point = pointlist[i] # creates a copy of the point
        point.x = round(floor(point.x/0.25)*2.5) / 10
        point.y = round(ceil(point.y/0.25)*2.5) / 10
        point.z = round(floor(point.z/0.3)*3) / 10
        pointlist[i] = point # overwrites the old point data with the new data

Наконец, прежде чем переписывать всю свою базу кода, вы должны просмотреть свой код, чтобы узнать, какие функции программа проводит большую часть своего времени, и оптимизировать эти функции, прежде чем беспокоиться об оптимизации других функций.

  • 0
    У меня возникла проблема с попыткой сделать то, что вы мне сказали, и у меня есть массив объектов, поэтому я не знаю, как отсортировать этот массив, используя пример, который вы мне прислали. Решением было бы создание нового массива, содержащего в каждом столбце значение каждого объекта атрибута, но это заняло бы много времени. Я также выполняю сериализацию массива numpy с помощью cPickle.dump и загружаю его с помощью cPickle.load. Я думаю, что в этом случае нет никаких проблем, потому что я сохраняю и загружаю массив numpy, не так ли? Я собираюсь сделать другие вещи, которые вы мне написали. большое спасибо
  • 0
    В соответствии с тем, что вы мне написали, я думаю, что я должен работать с массивами с разными столбцами вместо массивов объектов. Компиляция кода C У меня есть некоторые проблемы с объявлениями упакованных cdef структур, потому что в моем случае это класс, поэтому я думаю, что я попытаюсь написать весь код и классы функций в одном файле. Я не знаю, прав ли я, но я собираюсь попробовать
Показать ещё 5 комментариев

Ещё вопросы

Сообщество Overcoder
Наверх
Меню