如何让PyTorch与C++联动(1.x版)

ECCV投稿结束一个月了,之前一个研究周期点了一下新的技能点,就是怎么写PyTorch的C++ Extension。网上的教程多是旧体系的写法,我试着摸索了一下新版本该怎么写,现在整理一下留个备份。

原理

Python这门语言大家公认很简单,但是有一些很致命的问题,最典型的就是部分过程的执行效率不太够,例如循环。有的算法套上多层循环,在数据量太大的时候就非常慢。

众所周知,Python和C/C++是可以进行联合编程的。基础思路如下,首先,实现该算法的C/C++版本;然后利用Python.h中定义的PyObject,实现接口代码与前端的Python Code建立关联;最后编写Python代码,例如说利用setuptools实现一个编译的配置,执行这段代码自动完成编译流程。Python便可以借由此方法完成提速。

同理,对于PyTorch,我们也可以采用类似的办法,将需要的算法封装成动态链接库,然后在Python中调用,从而为原本用Python实现的算法提速。

流程

首先,我们要能够写出来该算法的CPP版本,例如说,有一个复杂度为O(n^3)的算法,传入多个变量返回一个数组,数组的指针借由参数传入,我们可以按照常例将这个算法写出来,就跟打ICPC时写题一样。

1
2
3
4
5
6
7
8
9
10
// Skeleton of an O(n^3) algorithm written as plain C++: the input matrix is
// read through `mat` and results are meant to be written through the
// caller-provided `ans` buffer.
// Parameters cannot carry the `static` storage class, so the pointers are
// plain (`mat` is const because it is only read), and the problem size `n`
// is passed in explicitly instead of being an undeclared name.
void foo(const float * mat, float * ans, int n) {
    // blabla...
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            for (int k = 0; k < n; k++) {
                //ops;
            }
        }
    }
}

在这里需要注意,由于实现的算法通常需要保证可导以便顺利地执行前向传播和反向传播的运算,所以在实现时,我们通常会同时实现一个算法的前向传播和反向传播运算,出于习惯会使用forward/backward后缀进行表示,我们会定义foo_forward_kernel/foo_backward_kernel来表示核心运算函数。

第二步,借助torch/extension.h中封装好的类型at::Tensor,构建对接PyTorch的函数,将传入的Tensor通过.data<T>()显式取出其底层数据的T*裸指针,然后调用第一步写好的函数去执行。注意输出Tensor需要先分配好内存(例如用at::zeros_like),才能取它的指针交给核心函数写入。以foo_forward为例,我们一般会表示成如下的样子。

1
2
3
4
5
6
7
8
// PyTorch-facing wrapper: unwraps the input tensor into a raw float buffer,
// runs foo_forward_kernel on it and returns the result as a new tensor.
// Assumes Mat is a float32 tensor (data<float>() throws otherwise).
at::Tensor foo_forward(at::Tensor Mat) {
    // blabla...
    // The kernel walks raw memory, so it needs a dense buffer; contiguous()
    // is a no-op when Mat already is one.
    Mat = Mat.contiguous();
    // The output must be allocated before its pointer is handed to the
    // kernel: a default-constructed at::Tensor is undefined and has no data.
    // Same shape/dtype/device as Mat is assumed here -- TODO adjust to the
    // algorithm's real output shape.
    at::Tensor ans = at::zeros_like(Mat);
    // blabla...
    foo_forward_kernel(Mat.data<float>(), ans.data<float>());
    // blabla...
    return ans;
}

第三步,编写基于setuptools的脚本,例如命名叫做setup.py,然后执行即可(例如python setup.py install,或者python setup.py build_ext --inplace在源码目录下就地编译)。示例代码如下。

1
2
3
4
5
6
7
8
9
from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CppExtension


def get_extensions():
    """Return the extension modules to build.

    Minimal placeholder so this snippet runs on its own; replace the source
    list with the real C++/CUDA files of the project (the full setup.py
    example collects them from a ``csrc`` directory).
    """
    return [CppExtension("cfoo", ["foo.cpp"])]


setup(
    name="cfoo",
    version="1.14.514",
    description="example code",
    packages=find_packages(exclude=("configs", "tests",)),
    ext_modules=get_extensions(),
    # BuildExtension supplies the compiler flags/toolchain PyTorch needs.
    cmdclass={"build_ext": BuildExtension},
)

等等,还没完,你们不觉得直接这么使用forward/backward函数很不方便么?幸运的是,PyTorch内建的Module和Function可以轻松地封装任意运算的一对前向与反向操作,在执行过程中可以直接调用,并且自动完成反向传播的过程。这部分不做示例,但后面的样例代码会提到。

样例

在这里我们使用Tang et al.提出的Proposal Cluster Learning Loss Function进行示例。

考虑到方便在GPU上运行,我们就随便写一点CUDA代码啦。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
// pcl_loss_cuda.cu
#include <stdio.h>
#include <vector>
#include <math.h>
#include <float.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include <THC/THC.h>
#include <THC/THCAtomics.cuh>
#include <THC/THCDeviceUtils.cuh>

#include "cuda/vision.h"

// Ceiling division: number of size-`n` chunks needed to cover `m` items.
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))

// Grid-stride loop over [0, n): each thread starts at its global index and
// advances by the total number of launched threads. `n` is parenthesized so
// that expression arguments expand correctly.
#define CUDA_1D_KERNEL_LOOP(i, n) \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)

// CUDA: grid stride looping
#define CUDA_KERNEL_LOOP(i, n) \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
       i < (n); \
       i += blockDim.x * gridDim.x)

// Forward kernel of the PCL loss. Each grid-stride iteration handles one
// output slot `c` in [0, nthreads) (the launcher passes nthreads = channels)
// and writes the loss term for that channel into top_data[c].
//   - channels with im_labels[c] == 0 get 0;
//   - channel 0 accumulates -cls_loss_weights[r] * log(bottom_data[r, 0])
//     over rows r whose label is 0 (bottom_data is read row-major with
//     nthreads columns);
//   - any other channel accumulates -img_cls_loss_weights[p] * log(pc_probs[p])
//     over the num_positive clusters p whose pc_labels[p] equals c.
__global__ void PCLLossesForward(const int nthreads, const float* bottom_data,
    const float* labels, const float* cls_loss_weights, const float* pc_labels,
    const float* pc_probs, const float* img_cls_loss_weights,
    const float* im_labels, const int batch_size, const int num_positive, float* top_data)
{
    CUDA_KERNEL_LOOP(c, nthreads)
    {
        top_data[c] = 0;

        // Channel not present in the image: its term stays 0.
        if (im_labels[c] == 0) {
            continue;
        }

        if (c != 0) {
            // Non-zero channel: sum over the clusters assigned to this label.
            for (int p = 0; p < num_positive; ++p) {
                if (pc_labels[p] == c) {
                    top_data[c] -= img_cls_loss_weights[p] * log(pc_probs[p]);
                }
            }
            continue;
        }

        // Channel 0: sum over the rows labelled 0.
        for (int r = 0; r < batch_size; ++r) {
            if (labels[r] == 0) {
                top_data[c] -= cls_loss_weights[r] * log(bottom_data[r * nthreads + c]);
            }
        }
    }
}

// Launches PCLLossesForward on `stream` with one thread slot per channel.
// Returns 1 on success. A launch failure is raised as a c10::Error (which
// surfaces in Python as a RuntimeError) instead of calling exit(), which
// would have terminated the whole Python process.
int PCLLossesForwardLaucher(
    const float* bottom_data, const float* labels, const float* cls_loss_weights,
    const float* pc_labels, const float* pc_probs, const float* img_cls_loss_weights,
    const float* im_labels, const int batch_size, const int channels,
    const int num_positive, float* top_data, cudaStream_t stream)
{
    const int kThreadsPerBlock = 4;

    // A zero-block grid is an invalid launch configuration; with no channels
    // there is nothing to compute.
    if (channels <= 0) {
        return 1;
    }

    const int num_blocks = (channels + kThreadsPerBlock - 1) / kThreadsPerBlock;
    PCLLossesForward<<<num_blocks, kThreadsPerBlock, 0, stream>>>(
        channels, bottom_data, labels, cls_loss_weights, pc_labels, pc_probs, img_cls_loss_weights,
        im_labels, batch_size, num_positive, top_data);

    const cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        AT_ERROR("cudaCheckError() failed : ", cudaGetErrorString(err));
    }

    return 1;
}


// Backward kernel of the PCL loss. Each grid-stride iteration handles one
// flat element idx = row * channels + cls of bottom_diff (the launcher passes
// nthreads = batch_size * channels) and writes its gradient:
//   - 0 when im_labels[cls] == 0;
//   - for cls == 0 and labels[row] == 0: -cls_loss_weights[row] / prob_data[idx];
//   - for cls != 0 and labels[row] == cls: with cluster = gt_assignment[row],
//     -img_cls_loss_weights[cluster] / (pc_count[cluster] * pc_probs[cluster]),
//     printing a warning if pc_labels[cluster] disagrees with cls;
//   - 0 otherwise.
__global__ void PCLLossesBackward(const int nthreads, const float* prob_data,
    const float* labels, const float* cls_loss_weights, const float* gt_assignment,
    const float* pc_labels, const float* pc_probs, const float* pc_count,
    const float* img_cls_loss_weights, const float* im_labels, const int channels,
    float* bottom_diff) {
    CUDA_1D_KERNEL_LOOP(idx, nthreads)
    {
        const int row = idx / channels;
        const int cls = idx % channels;
        bottom_diff[idx] = 0;

        // Channel not present in the image: no gradient.
        if (im_labels[cls] == 0) {
            continue;
        }

        if (cls == 0) {
            if (labels[row] == 0) {
                bottom_diff[idx] = -cls_loss_weights[row] / prob_data[idx];
            }
            continue;
        }

        // Non-zero channel: only rows carrying this label get a gradient.
        if (labels[row] != cls) {
            continue;
        }

        const int cluster = gt_assignment[row];
        if (cls != pc_labels[cluster]) {
            printf("labels mismatch.\n");
        }
        bottom_diff[idx] = -img_cls_loss_weights[cluster]
            / (pc_count[cluster] * pc_probs[cluster]);
    }
}

// Launches PCLLossesBackward on `stream` over batch_size * channels elements.
// `top_diff` is accepted but not forwarded to the kernel, so the result is
// the gradient of the unscaled loss. Returns 1 on success. A launch failure
// is raised as a c10::Error (a Python RuntimeError) instead of calling exit(),
// which would have terminated the whole Python process.
int PCLLossesBackwardLaucher(const float* top_diff, const float* prob_data,
    const float* labels, const float* cls_loss_weights, const float* gt_assignment,
    const float* pc_labels, const float* pc_probs, const float* pc_count,
    const float* img_cls_loss_weights, const float* im_labels, const int batch_size,
    const int channels, float* bottom_diff, cudaStream_t stream)
{
    const int kThreadsPerBlock = 16;
    const int output_size = batch_size * channels;

    // A zero-block grid is an invalid launch configuration; an empty output
    // needs no work.
    if (output_size <= 0) {
        return 1;
    }

    const int num_blocks = (output_size + kThreadsPerBlock - 1) / kThreadsPerBlock;
    PCLLossesBackward<<<num_blocks, kThreadsPerBlock, 0, stream>>>(
        output_size, prob_data, labels, cls_loss_weights, gt_assignment, pc_labels, pc_probs, pc_count,
        img_cls_loss_weights, im_labels, channels, bottom_diff);

    const cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        AT_ERROR("cudaCheckError() failed : ", cudaGetErrorString(err));
    }

    return 1;
}

// Host entry point for the PCL loss forward pass.
// Expects float32 CUDA tensors (every input is read via data<float>() by a
// CUDA kernel). Returns a (1, channels) tensor of per-channel loss terms,
// where channels = pcl_probs.size(1).
at::Tensor pcl_losses_forward_cuda(const at::Tensor& pcl_probs,
                                   const at::Tensor& labels,
                                   const at::Tensor& cls_loss_weights,
                                   const at::Tensor& pc_labels,
                                   const at::Tensor& pc_probs,
                                   const at::Tensor& img_cls_loss_weights,
                                   const at::Tensor& im_labels) {
  // The kernel reads each input through a raw pointer and assumes a dense
  // row-major layout (element (i, c) of pcl_probs at i * channels + c), so
  // make every input contiguous first. This is a no-op for tensors that
  // already are, and prevents silently reading the wrong elements of a
  // transposed or sliced view.
  const at::Tensor probs_c = pcl_probs.contiguous();
  const at::Tensor labels_c = labels.contiguous();
  const at::Tensor cls_w_c = cls_loss_weights.contiguous();
  const at::Tensor pc_labels_c = pc_labels.contiguous();
  const at::Tensor pc_probs_c = pc_probs.contiguous();
  const at::Tensor img_cls_w_c = img_cls_loss_weights.contiguous();
  const at::Tensor im_labels_c = im_labels.contiguous();

  const int batch_size = probs_c.size(0);
  const int channels = probs_c.size(1);
  const int num_positive = pc_labels_c.size(1);

  // The kernel assigns every one of the `channels` slots, so an
  // uninitialised buffer is sufficient.
  at::Tensor output = at::empty({1, channels}, probs_c.options());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  PCLLossesForwardLaucher(
      probs_c.data<float>(),
      labels_c.data<float>(),
      cls_w_c.data<float>(),
      pc_labels_c.data<float>(),
      pc_probs_c.data<float>(),
      img_cls_w_c.data<float>(),
      im_labels_c.data<float>(),
      batch_size, channels, num_positive,
      output.data<float>(),
      stream
  );
  return output;
}

// Host entry point for the PCL loss backward pass.
// Expects float32 CUDA tensors. Returns a zero-initialised
// (batch_size, channels) tensor, shaped like pcl_probs, filled by the kernel
// with the gradient with respect to pcl_probs.
at::Tensor pcl_losses_backward_cuda(const at::Tensor& pcl_probs,
                                    const at::Tensor& labels,
                                    const at::Tensor& cls_loss_weights,
                                    const at::Tensor& gt_assignment,
                                    const at::Tensor& pc_labels,
                                    const at::Tensor& pc_probs,
                                    const at::Tensor& pc_count,
                                    const at::Tensor& img_cls_loss_weights,
                                    const at::Tensor& im_labels,
                                    const at::Tensor& top_grad) {
  // The kernel indexes raw buffers assuming a dense row-major layout
  // (prob_data[row * channels + cls]), so make the inputs it reads contiguous
  // first; a no-op for tensors that already are.
  const at::Tensor probs_c = pcl_probs.contiguous();
  const at::Tensor labels_c = labels.contiguous();
  const at::Tensor cls_w_c = cls_loss_weights.contiguous();
  const at::Tensor gt_assign_c = gt_assignment.contiguous();
  const at::Tensor pc_labels_c = pc_labels.contiguous();
  const at::Tensor pc_probs_c = pc_probs.contiguous();
  const at::Tensor pc_count_c = pc_count.contiguous();
  const at::Tensor img_cls_w_c = img_cls_loss_weights.contiguous();
  const at::Tensor im_labels_c = im_labels.contiguous();
  const at::Tensor top_grad_c = top_grad.contiguous();

  const int batch_size = probs_c.size(0);
  const int channels = probs_c.size(1);
  at::Tensor bottom_grad = at::zeros({batch_size, channels}, probs_c.options());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  PCLLossesBackwardLaucher(
      top_grad_c.data<float>(),
      probs_c.data<float>(),
      labels_c.data<float>(),
      cls_w_c.data<float>(),
      gt_assign_c.data<float>(),
      pc_labels_c.data<float>(),
      pc_probs_c.data<float>(),
      pc_count_c.data<float>(),
      img_cls_w_c.data<float>(),
      im_labels_c.data<float>(),
      batch_size, channels,
      bottom_grad.data<float>(),
      stream
  );
  return bottom_grad;
}

同时我们如下定义头文件。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
// cuda/vision.h
// Declarations of the CUDA host entry points for the PCL loss. The
// definitions live in pcl_loss_cuda.cu, which includes this header; a C++
// binding file would presumably include it too (TODO confirm).
#pragma once
#include <torch/extension.h>
// Forward pass: returns a (1, channels) tensor of per-channel loss terms,
// with channels = pcl_probs.size(1). Inputs are read as float32.
at::Tensor pcl_losses_forward_cuda(const at::Tensor& pcl_probs,
const at::Tensor& labels,
const at::Tensor& cls_loss_weights,
const at::Tensor& pc_labels,
const at::Tensor& pc_probs,
const at::Tensor& img_cls_loss_weights,
const at::Tensor& im_labels);

// Backward pass: returns a (batch_size, channels) gradient tensor shaped
// like pcl_probs. Inputs are read as float32.
at::Tensor pcl_losses_backward_cuda(const at::Tensor& pcl_probs,
const at::Tensor& labels,
const at::Tensor& cls_loss_weights,
const at::Tensor& gt_assignment,
const at::Tensor& pc_labels,
const at::Tensor& pc_probs,
const at::Tensor& pc_count,
const at::Tensor& img_cls_loss_weights,
const at::Tensor& im_labels,
const at::Tensor& top_grad);

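接下来就可以进入setup流程以编译对应的动态链接库。注意,下面的setup.py会把csrc/目录下的*.cpp当作入口文件,我们需要在其中用PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)把函数导出给Python,例如m.def("pcl_losses_forward", &pcl_losses_forward_cuda)以及m.def("pcl_losses_backward", &pcl_losses_backward_cuda);后面的pcl_loss.py正是通过_C调用这两个名字的。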

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# setup.py
#!/usr/bin/env python
import glob
import os

import torch
from setuptools import find_packages
from setuptools import setup
from torch.utils.cpp_extension import CUDA_HOME
from torch.utils.cpp_extension import CppExtension
from torch.utils.cpp_extension import CUDAExtension

requirements = ["torch", "torchvision"]


def get_extensions():
    """Build the list of extension modules passed to ``setup()``.

    Sources are collected from the ``csrc`` directory next to this file:
    ``csrc/*.cpp`` and ``csrc/cpu/*.cpp`` are always compiled, and
    ``csrc/cuda/*.cu`` is added when a CUDA build is selected. A CUDA build is
    selected when ``CUDA_HOME`` is known and either a GPU is visible or the
    environment variable ``FORCE_CUDA=1`` is set (useful when building on a
    machine without a GPU, e.g. inside a container). In that case the
    ``WITH_CUDA`` macro is defined and the half-precision nvcc flags are added.

    Returns:
        A one-element list holding a ``CppExtension`` or ``CUDAExtension``
        named ``faster_rcnn._C``.
    """
    this_dir = os.path.dirname(os.path.abspath(__file__))
    extensions_dir = os.path.join(this_dir, "csrc")

    # glob returns files in arbitrary order; sorting keeps builds reproducible.
    # The results are already absolute paths rooted at extensions_dir.
    main_file = sorted(glob.glob(os.path.join(extensions_dir, "*.cpp")))
    source_cpu = sorted(glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")))
    source_cuda = sorted(glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")))

    sources = main_file + source_cpu
    extension = CppExtension

    extra_compile_args = {"cxx": []}
    define_macros = []

    force_cuda = os.getenv("FORCE_CUDA", "0") == "1"
    if (torch.cuda.is_available() or force_cuda) and CUDA_HOME is not None:
        extension = CUDAExtension
        sources += source_cuda
        define_macros += [("WITH_CUDA", None)]
        extra_compile_args["nvcc"] = [
            "-DCUDA_HAS_FP16=1",
            "-D__CUDA_NO_HALF_OPERATORS__",
            "-D__CUDA_NO_HALF_CONVERSIONS__",
            "-D__CUDA_NO_HALF2_OPERATORS__",
        ]

    include_dirs = [extensions_dir]

    ext_modules = [
        extension(
            "faster_rcnn._C",
            sources,
            include_dirs=include_dirs,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
        )
    ]

    return ext_modules

# Package metadata plus the build hook: BuildExtension drives compilation of
# the extension modules returned by get_extensions().
_setup_kwargs = dict(
    name="faster_rcnn",
    version="0.1",
    description="object detection in pytorch",
    packages=find_packages(exclude=("configs", "tests")),
    # install_requires=requirements,
    ext_modules=get_extensions(),
    cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
)
setup(**_setup_kwargs)

最后,我们用PyTorch封装一下动态链接库内的函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# pcl_loss.py
import torch
from torch import nn
from torch.autograd import Function

from model.cpp.faster_rcnn import _C

class PCLLosses(Function):
    """Autograd wrapper around the compiled PCL loss kernels in ``_C``.

    ``forward`` sums the per-class loss terms returned by
    ``_C.pcl_losses_forward`` and divides by ``pcl_probs.size(0)``.
    ``backward`` returns a gradient only for ``pcl_probs``; every other input
    is treated as a constant (its gradient is ``None``).
    """

    @staticmethod
    def forward(ctx, pcl_probs, labels, cls_loss_weights, gt_assignment,
                pc_labels, pc_probs, pc_count, img_cls_loss_weights,
                im_labels):
        ctx.save_for_backward(pcl_probs, labels, cls_loss_weights,
                              gt_assignment, pc_labels, pc_probs, pc_count,
                              img_cls_loss_weights, im_labels)
        output = _C.pcl_losses_forward(pcl_probs, labels, cls_loss_weights,
                                       pc_labels, pc_probs,
                                       img_cls_loss_weights, im_labels)
        return output.sum() / pcl_probs.size(0)

    @staticmethod
    def backward(ctx, grad_output):
        (pcl_probs, labels, cls_loss_weights, gt_assignment, pc_labels,
         pc_probs, pc_count, img_cls_loss_weights, im_labels) = ctx.saved_tensors

        grad_input = _C.pcl_losses_backward(pcl_probs, labels, cls_loss_weights,
                                            gt_assignment, pc_labels, pc_probs,
                                            pc_count, img_cls_loss_weights,
                                            im_labels, grad_output)
        # Chain rule: the CUDA backward in pcl_loss_cuda.cu never reads
        # top_grad, so apply the upstream gradient here (otherwise any
        # scaling of the loss, e.g. a loss weight, is lost), together with
        # the 1/N factor from forward.
        grad_input = grad_input * (grad_output / pcl_probs.size(0))
        # One gradient per forward input: only pcl_probs is differentiable.
        return (grad_input.to(pcl_probs.device),) + (None,) * 8

到这里我们就完成啦。Enjoy Machine Learning!