Python的多进程与实战

写在前面

你们Python的多线程还是去死吧,跑起来比单线程还慢。

抱歉我说错了,是你们Official的Python解释器太沙雕了。

正文

1

既然我们喷了Python的多线程,那么我们首先要知道为什么它会被喷。

其实,Python的多线程是被一个叫做GIL的全局锁制约的,这个锁的存在,使得Python的多个线程真的是以这个时间片只有一个线程在转这样的过程执行的。这还没完,由于我是软件实现的全局锁,所以哪怕我的多个线程真的能分散到不同的核心上,他也只能按顺序轮流跑。好的,那么问题来了,既然我每次只能有一个线程在转,那么我的实际运行时间就是单线程运行时间+所有线程的轮转调度时间+每个变量线程间互斥锁处理的时间。显然,这肯定是比单线程跑还慢的。实际运行时也是如此,我出门吃个饭的时间多线程版脚本处理的量只比我出去冲杯咖啡的时间单线程版脚本处理的量多一点。

那么怎么办呢?很简单,多进程,这样可以确保每个进程都在不同的核心上被同时处理,只需要再额外处理好进程间数据共享即可。

2

下面就是介绍Python的多进程了。这里选择的是自带的multiprocessing库,入门来说足够使用了。
使用方法很简单,在需要的地方(例如文件头部)import multiprocessing即可开始使用。
通常来讲,我们需要import这些重要的东西。

第一个就是multiprocessing.Manager(),这个部分可以理解成是靠共享内存的方式共享了多进程同时使用的资源,当然这个同时是广义上的。Manager里面提供的类型例如队列或者管道是进程安全的,使用方式和常规版本一样,所以可以在多进程环境下放心使用。

第二个是multiprocessing.Process(),这个部分很好理解,就是子进程。创建的子进程会执行一个特定的函数,所以他的一个参数就是一个函数指针,另外的一个参数就是一个输入的函数对应的传入参数的tuple。

第三个是multiprocessing.Pool()。这个就是进程池。很好理解,就是限制你最大子进程数的一个类,当然也有很强大的进程管理的能力。用法很简单,通常会靠调用apply_async方法创建子进程,然后执行close方法表示我子进程创建完了,最后执行join方法进行监听,直到所有子进程结束。

实战

我们来给定一个场景,有一个Dataset D,size(D) > 1e5,D内部是三通道图像di,0\<=i\<size(D),dim(D) > 1280*720。其中部分图像是黑白的,我们需要过滤掉这些数据。

过滤图片的算法很简单,三通道形式读取图片,然后计算每个像素上的RGB三个值的标准差,对全图所有像素的标准差求个平均,0就是黑白图,当然过滤掉这种图的阈值往往会卡一下。

那么怎么写一个这样的程序呢?
直接贴代码好了。

#encoding: utf-8
from skimage.io import imread, imsave
from math import sqrt
import numpy as np
import sys
import os
from argparse import ArgumentParser
import multiprocessing
from multiprocessing import Manager
#images path
ROOT_DIR_PATH = "Data Root Path"
SOURCE_PATH = ROOT_DIR_PATH + "listfile"
# create shared resources/variables
manager = multiprocessing.Manager()
que_ = manager.Queue()
que_out = manager.Queue()
#calculate the average stddev
def calc(path):
    mat = imread(path).astype(np.float32)
    stddev = 0.0
    shp = mat.shape
    for i in range(shp[0]):
        for j in range(shp[1]):
            avg = float(mat[i][j][0] + mat[i][j][1] + mat[i][j][2]) / 3
            stddev += sqrt((avg - mat[i][j][0]) ** 2 + (avg - \
                        mat[i][j][1]) ** 2 + (avg - mat[i][j][2]) ** 2)
    stddev /= shp[0] * shp[1]
    return stddev < 1
#the function in subprocess
def run(dir_path, queue_in, queue_out):
    while queue_in.qsize() > 0:
        pi = queue_in.get()
        path = dir_path + pi.split()[0] # dir + filename
        queue_in.task_done()
        if (path.endswith("png") or path.endswith("jpg") or \
                path.endswith("jpeg")) == False:
            continue
        if calc(path) == False:
            queue_out.put(pi)
            print(piece)
            sys.stdout.flush()
#work function, you know
def main():
    parse = ArgumentParser()
    parse.add_argument("-n", action="store", type=int, \
                   help="number of processes to execute", dest="num", default=8)
    args = parse.parse_args()
    _write = open("output", "w")
    with open(SOURCE_PATH, "r") as f:
        files = f.readlines()
        for fil in files:
            que_.put(fil)
        f.close()
    pool = multiprocessing.Pool()
    for i in range(args.num):
        train_pool.apply_async(run, args=(ROOT_DIR_PATH, que_, que_out,))
    pool.close()
    pool.join()
    while que_out.empty() == False:
        piece = que_out.get()
        _write.write(piece)
        _write.flush()
        que_out.task_done()
    _write.close()
#start at here, main function, you know
if __name__ == "__main__":
    main()


发表评论

电子邮件地址不会被公开。

This site uses Akismet to reduce spam. Learn how your comment data is processed.