Post

Torch extension 사용법

Torch extension 사용법

1. 개요

원래 GPU를 구동할 수 있는 코드는 C++을 이용해서만 짤 수 있었다.
하지만 개발간 Python이 매우 많이 사용됨으로 인해, C++로 짜여진 커널을 Python에서 구동할 수 있으면 좋겠다는 니즈가 생겼고 이로 인해 많은 라이브러리들이 CUDA 커널을 지원하기 시작했다.

이번에는 많은 라이브러리들 중에 Torch extension에 대해서 이야기할 것이다. PyTorch Extension은 C++/CUDA 커스텀 코드를 PyTorch에 통합하여 성능을 최적화하거나 새로운 연산을 추가할 수 있게 해주는 기능이다. 기본 PyTorch 연산으로 표현할 수 없거나 성능 최적화가 필요한 경우에 사용한다.

2. 방법

총 2가지 사용 방법이 있다.

1) setuptools를 이용한 Ahead-of-Time 빌드

가장 일반적인 방법으로, setup.py 파일을 작성하여 사전 컴파일한 뒤 Python에서 호출해서 사용하는 방식이다.

순서는 아래와 같다.

  1. C++ 로 CUDA 커널을 짠다.
  2. C++ 로 인터페이스를 짠다.
  3. CUDA 커널과 인터페이스를 Bind 한다.
  4. Extension을 setup tool로 컴파일 한다.
  5. 컴파일된 Extension을 Python에서 호출해서 사용한다.

a. 사용 예시

이번 예시에서는 행렬 곱셈을 하는 커널을 만들어서 pytorch에서 사용할 수 있도록 해보겠다.
$N \times N $ 정방 행렬 간의 곱셈을 하는 코드를 만들어서 구동해보도록 하겠다.

  1. 디렉터리 구조 생성 아래대로 디렉터리 구조를 생성한다.
    1
    2
    3
    4
    5
    6
    
    src/
    ├─ my_extension.cpp
    ├─ my_kernel.h
    ├─ my_kernel.cu
    setup.py
    test.py
    

실질적으로 CUDA 및 C++ 코드가 들어있는 src 폴더와 코드를 빌드하기 위한 setup.py 파일 그리고 제대로 빌드되었는지 확인하기 위한 test.py 이다.

  1. CUDA 커널 생성 my_kernel.cu 라는 이름으로 파일을 만들고 커널 코드를 작성한다. ```cuda #include #include "my_kernel.h"

global void matrixMulKernel(const float* A, const float* B, float* C, int N) { int row = blockIdx.y * blockDim.y + threadIdx.y; int col = blockIdx.x * blockDim.x + threadIdx.x; if (row < N && col < N) { float value = 0; for (int k = 0; k < N; ++k) { value += A[row * N + k] * B[k * N + col]; } C[row * N + col] = value; } }

void matrixMul(const float* A, const float* B, float* C, int N) { float *d_A, *d_B, *d_C; size_t size = N * N * sizeof(float);

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
cudaMalloc(&d_A, size);
cudaMalloc(&d_B, size);
cudaMalloc(&d_C, size);

cudaMemcpy(d_A, A, size, cudaMemcpyHostToDevice);
cudaMemcpy(d_B, B, size, cudaMemcpyHostToDevice);

dim3 threadsPerBlock(16, 16);
dim3 numBlocks((N + threadsPerBlock.x - 1) / threadsPerBlock.x, 
               (N + threadsPerBlock.y - 1) / threadsPerBlock.y);
matrixMulKernel<<<numBlocks, threadsPerBlock>>>(d_A, d_B, d_C, N);

cudaMemcpy(C, d_C, size, cudaMemcpyDeviceToHost);

cudaFree(d_A);
cudaFree(d_B);
cudaFree(d_C); }
1
2
3
4
5
6
7
8
9
10
11
12
13
2. kernel Header 생성
해당 Kernel의 함수를 인식하게 할 수 있는 커널 헤더를 만든다. 커널에 정의한 함수와 동일한 반환값과 인자를 주어야하며   
실질적으로 외부에 노출되는 함수이다.
```cpp
#ifndef MY_KERNEL_H
#define MY_KERNEL_H

#include <torch/extension.h>

void matrixMul(const float* A, const float* B, float* C, int N);

#endif // MY_KERNEL_H
  1. CPP Wrapper 생성 커널 함수를 CPP로 감싸준다. 여기서 TORCH EXTENSION 헤더에서 제공하는 것들로 타입 체크와 그외의 필요한 바인딩을 해준다. ```cpp #include <torch/extension.h> #include “my_kernel.h”

void multiply_matrices(torch::Tensor a, torch::Tensor b, torch::Tensor result) {

1
2
3
4
5
6
7
8
9
TORCH_CHECK(a.device() == b.device(), "Input tensors must be on the same device");
TORCH_CHECK(a.device() == result.device(), "Result tensor must be on the same device");


TORCH_CHECK(a.size(1) == b.size(0), "Incompatible dimensions for matrix multiplication");


int64_t rows = a.size(0);
matrixMul(a.data_ptr<float>(), b.data_ptr<float>(), result.data_ptr<float>(), rows); }

// Bind the functions to Python PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def(“multiply”, &multiply_matrices, “Matrix multiplication (CUDA)”); }

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
4. Build를 위한 setup.py 생성
여기서 빌드할때 필요한 파일을 추가하며 cmdclass에서 빌드간 사용하기 위한 명령어를 정의한다.
```python
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
  name='cuda_matrix_ops',
  ext_modules=[
    CUDAExtension(
      'cuda_matrix_ops',
      ['src/my_extension.cpp', 'src/my_kernel.cu'],
    ),
  ],
  cmdclass={
    'build_ext': BuildExtension
  }
)
  1. 빌드 실제로 아래 명령어를 통해 빌드한다.
    1
    
    python3 setup.py build_ext --inplace
    
  2. 테스트 실제 구동되면 Result 이후에 4x4 배열의 곱이 나온다. ```python import sys import torch

현재 디렉토리를 경로에 추가 (로컬 .so 파일 인식)

sys.path.insert(0, ‘{빌드한 so 파일이 있는 폴더 경로}’)

import cuda_matrix_ops

def test(): device = torch.device(‘cuda’ if torch.cuda.is_available() else ‘cpu’) a = torch.randn(4, 4, device=device, dtype=torch.float32) b = torch.randn(4, 4, device=device, dtype=torch.float32) result = torch.zeros(4, 4, device=device, dtype=torch.float32)

1
2
cuda_matrix_ops.multiply(a, b, result)
print("Result:", result)

if name == “main”: test()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
### 2) Just-In-Time (JIT) 컴파일
총 두 가지 방식이 있는데 커널과 바인딩 코드를 포함하는 .cpp 파일과 .cu 파일을  torch.utils.cpp_extension.load()을 이용하여 불러와서 사용하는 방식과
c++ 소스 코드를 Python 문자열로 torch.utils.cpp_extension.load_inline()에 직접 전달하여 사용하는 방식 두 가지가 있다.

아래에서 예시와 함께 살펴보겠다.

#### a. 파일 불러와서 사용
1. C++ 로 CUDA 커널을 짠다.
2. C++ 로 인터페이스를 짠다.   
※ setuptools를 이용한 Ahead-of-Time 빌드 방식에서 사용하던 예시를 그대로 가져오겠다.   
3. Python에서 커널 바인드 및 컴파일해서 호출해서 사용한다.
```python3
from torch.utils.cpp_extension import load
import torch

cuda_matrix_ops = load(
  name='my_extension',
  sources=['src/my_kernel.cu', 'src/my_extension.cpp'],
  verbose=True
)

def test():
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  a = torch.randn(4, 4, device=device, dtype=torch.float32)
  b = torch.randn(4, 4, device=device, dtype=torch.float32)
  result = torch.zeros(4, 4, device=device, dtype=torch.float32)

  cuda_matrix_ops.multiply(a, b, result)
  print("Result:", result)

if __name__ == "__main__":
  test()

b. Python 코드 내에서 사용

위에서 사용하는 코드를 그대로 옮겼습니다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from torch.utils.cpp_extension import load_inline
import torch

cuda_source = """
#include <cuda_runtime.h>

__global__ void matrixMulKernel(const float* A, const float* B, float* C, int N) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row < N && col < N) {
        float value = 0;
        for (int k = 0; k < N; ++k) {
            value += A[row * N + k] * B[k * N + col];
        }
        C[row * N + col] = value;
    }
}

void matrixMul(const float* A, const float* B, float* C, int N) {
    float *d_A, *d_B, *d_C;
    size_t size = N * N * sizeof(float);
    
    cudaMalloc(&d_A, size);
    cudaMalloc(&d_B, size);
    cudaMalloc(&d_C, size);
    
    cudaMemcpy(d_A, A, size, cudaMemcpyHostToDevice);
    cudaMemcpy(d_B, B, size, cudaMemcpyHostToDevice);
    
    dim3 threadsPerBlock(16, 16);
    dim3 numBlocks((N + threadsPerBlock.x - 1) / threadsPerBlock.x, 
                   (N + threadsPerBlock.y - 1) / threadsPerBlock.y);
    matrixMulKernel<<<numBlocks, threadsPerBlock>>>(d_A, d_B, d_C, N);
    
    cudaMemcpy(C, d_C, size, cudaMemcpyDeviceToHost);
    
    cudaFree(d_A);
    cudaFree(d_B);
    cudaFree(d_C);
}
"""

cpp_source = """
#include <torch/extension.h>

void matrixMul(const float* A, const float* B, float* C, int N);

// Function to multiply two matrices
void multiply_matrices(torch::Tensor a, torch::Tensor b, torch::Tensor result) {
    // Check that the input tensors are on the same device
    TORCH_CHECK(a.device() == b.device(), "Input tensors must be on the same device");
    TORCH_CHECK(a.device() == result.device(), "Result tensor must be on the same device");
    
    // Check that the dimensions are compatible for multiplication
    TORCH_CHECK(a.size(1) == b.size(0), "Incompatible dimensions for matrix multiplication");
    
    // Launch the CUDA kernel for matrix multiplication
    int64_t rows = a.size(0);
    matrixMul(a.data_ptr<float>(), b.data_ptr<float>(), result.data_ptr<float>(), rows);
}

"""

cuda_matrix_ops = load_inline(
    name='my_extension',
    cpp_sources=[cpp_source],
    cuda_sources=[cuda_source],
    functions=['multiply_matrices'],
    verbose=True
)

def test():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    a = torch.randn(4, 4, device=device, dtype=torch.float32)
    b = torch.randn(4, 4, device=device, dtype=torch.float32)
    result = torch.zeros(4, 4, device=device, dtype=torch.float32)
    
    cuda_matrix_ops.multiply_matrices(a, b, result)
    print("Result:", result)

if __name__ == "__main__":
    test()

이전 방식에서 PYBIND11_MODULE로 바인딩하던 것은 load_inline에서 해주기 때문에 함수이름이 변경되었습니다.

참고문헌

  • https://tutorials.pytorch.kr/advanced/cpp_custom_ops.html
This post is licensed under CC BY 4.0 by the author.