Reverse-Engineering cuBLAS

By Fabian Schuetze

Overload, 32(181):9-13, June 2024


It’s possible to achieve cuBLAS performance with tensor cores by mimicking SASS instructions. Fabian Schuetze guides us through the process.

Glossary

A5000 (GPU): A GPU produced by Nvidia. The A5000 is based on the Ampere microarchitecture. The article uses specialized instructions introduced with Ampere. The subsequent microarchitecture (Hopper) introduced new instructions that are needed to attain maximum performance on Hopper GPUs.

BLAS (and GEMM): GEMM stands for General Matrix Multiplication. It belongs to the group of matrix-matrix operations (called Level 3) of the Basic Linear Algebra Subprograms (BLAS). A standardized interface to the BLAS will become part of C++26 (std::linalg), as proposed by P1673.

cuBLAS: Nvidia’s variant of the BLAS library. It contains highly optimized and specialized code for all GPU variants and matrix sizes. Its source code is not publicly accessible.

CUDA: An extension of the C and C++ languages for writing programs for Nvidia GPUs. CUDA lets programmers explicitly control a portion of the GPU's L1 cache (shared memory).

PTX: PTX (Parallel Thread Execution) describes an idealized virtual machine that models an archetypal Nvidia GPU, together with its instruction set architecture (ISA). CUDA code compiles to PTX, which is further translated to (undocumented) SASS code. Programmers can also write PTX code directly.

SASS: An undocumented assembly language for Nvidia GPUs. It translates to binary microcode that gets executed on an actual target.

Importance of GEMM and GPUs

Matrix multiplication is at the heart of linear algebra and the core of scientific, engineering, and statistical computation. Many variants of matrix multiplication can be expressed through the interface of the Basic Linear Algebra Subprograms (BLAS). The BLAS is the de facto standard low-level interface for matrix multiplications, and its influence is hard to overstate. For example, Nature named the BLAS one of ten computer codes that transformed science [Perkel21]. Moreover, Jack J. Dongarra received the 2021 Turing Award [ACM21]; the ACM describes him as:

the primary implementor or principal investigator for [...] BLAS. [...] The libraries are used, practically universally, for high performance scientific and engineering computation on machines ranging from laptops to the world’s fastest supercomputers.

Finally, with C++26, programmers can interface with the BLAS directly from C++ (in the std::linalg namespace), thanks to P1673.

Because the low-level interface for matrix multiplication follows a de facto standard, and because of its importance, hardware vendors offer dedicated implementations. These libraries are highly optimized, but their source code is often undisclosed. Matrix multiplications comprise many small, independent computations and are therefore well-suited to GPUs. Consequently, AMD, ARM, Nvidia, and Intel offer libraries for their GPUs. GPUs are, in essence, vector processors: they have enormous numbers of cores, each of which is simple compared with the cores of modern CPUs. Their memory units are also simple but provide huge throughput. To attain maximum performance, programmers commonly control explicitly how data is loaded into caches.

This article extracts the essence of such computations by reverse-engineering a matrix multiplication kernel from Nvidia’s BLAS library (cuBLAS). The implementation is simple yet instructive and attains performance almost on par with the cuBLAS variant. Re-engineering the cuBLAS kernel is not too difficult when good abstractions are used as building blocks. The kernels provided with cuBLAS are heavily tuned, and the best-performing kernel is selected at runtime from a large pool: one can count ~5000 kernels whose names contain GEMM, and cuBLAS ships a whopping 100MB. In comparison, the BLAS library provided by Ubuntu, libblas, ships 600KB.

The performance of three different handwritten CUDA kernels and the cuBLAS version is shown in Figure 1.

Figure 1

The three versions differ in their use of PTX primitives (PTX can be understood as a mid-level IR for Nvidia GPUs) and in the degree of instruction-level parallelism (ILP) they attain. A high ILP can be achieved by writing efficient abstractions and placing them well in the code, so that data can be prefetched and pipeline stalls avoided. Modern PTX instructions are needed for asynchronous and highly efficient loading from global memory. The resulting efficiency is documented by the kernels’ ILP, shown in Figure 2.

Figure 2

Note that, for readers used to CPU optimization, the ILP is extremely high; this is explained by the extensive parallelism GPUs offer.

This article proceeds in the following stages. First, the basic GEMM implementation using tensor cores is shown. Second, the SASS (CUDA assembly) code of the highly optimized cuBLAS kernel is analyzed, and its differences from the instructions of the basic implementation are identified. Finally, the basic implementation is refined in two steps to reach performance parity with cuBLAS.

Basic GEMM Implementation

The main loop of the basic implementation of a GEMM kernel with tensor cores is shown in Listing 1. It illustrates the basic structure of a decent GEMM kernel with tensor cores: looping along the K (inner) dimension of the matrix product in blocks, the kernel loads blocks of the matrices A and B into shared memory. The load function is named load_blocking (which already hints at future optimizations). The kernel then uses a nested loop to compute the matrix product over these blocks: smaller blocks of shared memory are loaded into local register files, and their matrix product is calculated. The kernel reaches about 60 TFLOPS on an A5000, or ⅔ of the GPU limit.

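for (size_t block = 0; block < K; block += Threadblock::kK) {
  // Copy the next K-slice of A and B from global
  // memory to shared memory (via registers).
  LoaderA.load_blocking();
  LoaderB.load_blocking();
  LoaderA.next(Threadblock::kK);
  LoaderB.next(Threadblock::kK * N);
  __syncthreads(); // tiles are complete for all threads
  constexpr size_t wmma_steps 
    = Threadblock::kK / WMMAblock::kK;
  for (size_t wmma_step = 0; 
       wmma_step < wmma_steps; wmma_step++) {
    // Load fragments from shared memory into
    // registers and multiply them on the tensor cores.
    RegisterLoaderA.load();
    RegisterLoaderB.load();
    RegisterLoaderA.step(WMMAblock::kK);
    RegisterLoaderB.step
      (Bs.cols_ * WMMAblock::kN);
    matmul.compute();
  }
  RegisterLoaderA.reset(0);
  RegisterLoaderB.reset(0);
  __syncthreads(); // all reads are done before the
                   // tiles are overwritten
}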
Listing 1
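To see the loop structure of Listing 1 without any GPU machinery, the following host-side C++ sketch mirrors it on plain arrays: a K-slice of A and B is copied into local tiles (standing in for shared memory and load_blocking), and an inner loop walks each tile in smaller fragments (standing in for the register loads and matmul.compute). The names blocked_gemm, kTileK and kFragK are hypothetical, the tile sizes are arbitrary, and the register-file level is folded into the innermost loop.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical tile sizes: kTileK stands in for Threadblock::kK and
// kFragK for WMMAblock::kK from Listing 1.
constexpr std::size_t kTileK = 32;
constexpr std::size_t kFragK = 16;

// Host-side analogue of the loop structure of Listing 1 (not the kernel):
// C (M x N) += A (M x K) * B (K x N), all row-major.
// Assumes K is a multiple of kTileK. Only K is tiled here; on the GPU,
// each threadblock would also own just an M x N sub-tile of C.
void blocked_gemm(const std::vector<float>& A, const std::vector<float>& B,
                  std::vector<float>& C, std::size_t M, std::size_t N,
                  std::size_t K) {
  std::vector<float> As(M * kTileK);  // stands in for the shared tile of A
  std::vector<float> Bs(kTileK * N);  // stands in for the shared tile of B
  for (std::size_t block = 0; block < K; block += kTileK) {
    // "load_blocking": copy the current K-slice of A and B into the tiles.
    for (std::size_t m = 0; m < M; ++m)
      for (std::size_t k = 0; k < kTileK; ++k)
        As[m * kTileK + k] = A[m * K + block + k];
    for (std::size_t k = 0; k < kTileK; ++k)
      for (std::size_t n = 0; n < N; ++n)
        Bs[k * N + n] = B[(block + k) * N + n];
    // "wmma_steps": walk the tile in fragments of kFragK along K and
    // accumulate each fragment's product into C.
    for (std::size_t frag = 0; frag < kTileK; frag += kFragK)
      for (std::size_t m = 0; m < M; ++m)
        for (std::size_t n = 0; n < N; ++n) {
          float acc = 0.0f;
          for (std::size_t k = frag; k < frag + kFragK; ++k)
            acc += As[m * kTileK + k] * Bs[k * N + n];
          C[m * N + n] += acc;
        }
  }
}

// Straightforward triple loop, used to check blocked_gemm.
void naive_gemm(const std::vector<float>& A, const std::vector<float>& B,
                std::vector<float>& C, std::size_t M, std::size_t N,
                std::size_t K) {
  for (std::size_t m = 0; m < M; ++m)
    for (std::size_t n = 0; n < N; ++n)
      for (std::size_t k = 0; k < K; ++k)
        C[m * N + n] += A[m * K + k] * B[k * N + n];
}
```

The blocked and the naive version perform the same arithmetic in a different order; the GPU kernel gains its speed from the fact that the tiles live in fast shared memory and are reused by many threads.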

The code in Listing 1 gets compiled to the following SASS assembly:

  ...
  LDG.E.128.CONSTANT R72, [R72.64]
  ...
  WARPSYNC 0xffffffff
  ...
  STS.128 [R143], R52
  ...
  BAR.SYNC 0x0
  LDSM.16.M88.4 R80, [R80]
  ...
  HMMA.16816.F16 R18, R80, R68, R18
  ...
  BAR.SYNC 0x0

The assembly reveals the inner workings of the code above. First, load_blocking loads 128 bits from global memory into thread-local registers (LDG.E.128.CONSTANT). After the global loads, all threads in the warp wait at a barrier (WARPSYNC). Then, the threads store the loaded data in shared memory (STS.128), and all threads in the block synchronize (BAR.SYNC). Next, data from shared memory is loaded as a matrix for processing by the tensor cores (LDSM), and a tensor core matrix multiplication with half floats ensues (HMMA.16816.F16). Finally, all threads in the block wait at a barrier before the loop starts again. The way data is loaded is pictured in Figure 3.

Figure 3

On the far right of the figure, 255MB are loaded from device memory into the L2 cache before landing in the L1 cache. As the top left of the figure shows, 3.41M instructions are used to load data into the local registers. From the local registers, the data is stored again in shared memory (a portion of the L1 cache) in 3.15M requests. From shared memory, the data is accessed in 11.53M requests.
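The HMMA.16816 instructions in the SASS listings are tensor core matrix multiply-accumulate operations. SASS is undocumented, so reading the 16816 suffix as the m16n8k16 shape of the PTX mma.sync instruction is an assumption here, albeit a common one. Under that assumption, the following host-side C++ sketch gives the reference semantics of one such operation, D = A·B + C with a 16×16 A and a 16×8 B, using float in place of half precision and ignoring how the operands are spread over the registers of the 32 threads of a warp. The function and type names are hypothetical.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Assumed shape of one HMMA.16816 operation: m16n8k16.
constexpr std::size_t kM = 16, kN = 8, kK = 16;

// Operands as plain row-major arrays; on the GPU they are distributed
// over the registers of a warp, and the inputs would be half precision.
using FragA = std::array<float, kM * kK>;
using FragB = std::array<float, kK * kN>;
using FragC = std::array<float, kM * kN>;

// Reference semantics: D = A * B + C.
FragC mma_m16n8k16(const FragA& a, const FragB& b, const FragC& c) {
  FragC d = c;
  for (std::size_t m = 0; m < kM; ++m)
    for (std::size_t n = 0; n < kN; ++n)
      for (std::size_t k = 0; k < kK; ++k)
        d[m * kN + n] += a[m * kK + k] * b[k * kN + n];
  return d;
}

// Floating point operations per instruction under this assumption:
// one multiply and one add for every (m, n, k) triple.
constexpr std::size_t kFlopsPerMma = 2 * kM * kN * kK;  // 4096
```

A single warp instruction therefore covers thousands of floating point operations, which is why keeping the tensor cores fed with data, rather than the arithmetic itself, dominates the design of the kernels that follow.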

SASS code for the cuBLAS kernel

The SASS code for the cuBLAS kernel is interesting. An abbreviated version reads as follows:

  HMMA.16816.F32 R0, R152, R184, R0
  LDSM.16.MT88.4 R168, [R137+UR8+0x800]
  LDGSTS.E.BYPASS.LTC128B.128.CONSTANT
  [R129+UR4+0x3000], [R130.64+0x180]
  ...
  HMMA.16816.F32 R4, R152, R186, R4
  HMMA.16816.F32 R8, R152, R188, R8
  ...
  HMMA.16816.F32 R120, R164, R196, R120
  DEPBAR.LE SB0, 0x1
  ...

The assembly code highlights several aspects. The main loop starts with a matrix multiplication instead of a memory load. The global load LDGSTS.E.BYPASS.LTC128B.128.CONSTANT differs from the load in the basic GEMM implementation, LDG.E.128.CONSTANT: firstly, it bypasses the registers and stores the data directly in shared memory. Furthermore, it is asynchronous and does not block the threads. A non-blocking load requires a separate barrier to signal when the data is ready; here, this is the dependency barrier DEPBAR.LE. Finally, the instructions are interleaved: there is no linear separation between loading data and operating on it, but a heavy mixture of instructions. The cuBLAS kernel achieves ~90 TFLOPS. The following two sections describe how to write code that produces similar SASS and attains the same performance.

Improvement I: buffering

Asynchronous load instructions

Starting with PTX Version 7.0 [PTX-1], CUDA provides instructions to copy data asynchronously from global to shared memory. The copy bypasses local registers and stores data directly to the shared memory (L1 cache). As identified above, asynchronous loading is one of the differences between the simple GEMM code and the cuBLAS version.

Two changes are necessary for asynchronous loading. The first is a new load function, shown in Listing 2; the previous version is shown in Listing 3.

__device__ void load(size_t counter) {
  const size_t global_idx = 
    offset_.row * ld_ + offset_.col;
  for (size_t row = 0; row < rows;
       row += stride_) {
    const T *src = 
      global_ptr_ + row * ld_ + global_idx;
    T *dst = 
      &shmem_(counter * rows + offset_.row + row,
              offset_.col); // + row * cols;
    constexpr size_t load_bytes = 16;
    uint32_t pos_in_ss = __cvta_generic_to_shared
      (reinterpret_cast<int4 *>(dst));
    CP_ASYNC_CG(pos_in_ss, src, load_bytes);
  }
}
Listing 2
__device__ void load_blocking() {
  const size_t global_idx = 
    offset_.row * ld_ + offset_.col;
  for (size_t row = 0; row < rows;
       row += stride_) {
    const T *src = 
      global_ptr_ + row * ld_ + global_idx;
    T *dst = &shmem_(offset_.row + row,
      offset_.col); // + row * cols;
    const int4 t = 
      reinterpret_cast<const int4 *>(src)[0];
    reinterpret_cast<int4 *>(dst)[0] = t;
  }
}
Listing 3

The load_blocking function loads 128 bits at a time by reinterpreting eight half floats as one int4. In contrast, the load function uses the macro CP_ASYNC_CG, which wraps the PTX instruction shown in Listing 4.

#define CP_ASYNC_CG(dst, src, Bytes)            \
  asm volatile(                                 \
    "cp.async.cg.shared.global.L2::128B [%0],"  \
    "[%1], %2;\n" ::"r"(dst), "l"(src),         \
    "n"(Bytes))
Listing 4

The compiler converts it into the same SASS instruction as can be seen in the cuBLAS code:

  LDGSTS.E.BYPASS.LTC128B.128 [R11], [R2.64]
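Because every LDGSTS...128 moves 16 bytes (eight half floats) per thread, the row loop in Listing 2 has to spread a shared-memory tile over the threads of a block. The article does not show how offset_ and stride_ are computed, so the following host-side C++ sketch assumes one plausible mapping (thread_offset and coverage are hypothetical helpers) and checks that each 16-byte chunk of the tile is copied exactly once.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

constexpr std::size_t kBytesPerCopy = 16;                  // one cp.async
constexpr std::size_t kHalvesPerCopy = kBytesPerCopy / 2;  // 8 halves

struct Offset { std::size_t row, col; };

// Hypothetical offset_ of thread `tid` in a tile that is `cols` halves
// wide: consecutive threads take consecutive 16-byte chunks of a row.
Offset thread_offset(std::size_t tid, std::size_t cols) {
  const std::size_t chunks_per_row = cols / kHalvesPerCopy;
  return {tid / chunks_per_row, (tid % chunks_per_row) * kHalvesPerCopy};
}

// Counts how often each chunk of a rows x cols tile is copied when
// `threads` threads run the row loop of Listing 2 with the mapping above.
// Assumes cols is a multiple of 8, threads a multiple of cols / 8, and
// rows a multiple of the resulting stride.
std::vector<int> coverage(std::size_t rows, std::size_t cols,
                          std::size_t threads) {
  const std::size_t chunks_per_row = cols / kHalvesPerCopy;
  const std::size_t stride = threads / chunks_per_row;  // stride_
  std::vector<int> hits(rows * chunks_per_row, 0);
  for (std::size_t tid = 0; tid < threads; ++tid) {
    const Offset off = thread_offset(tid, cols);
    for (std::size_t row = 0; row < rows; row += stride) {
      const std::size_t r = off.row + row;
      hits[r * chunks_per_row + off.col / kHalvesPerCopy] += 1;
    }
  }
  return hits;
}
```

With 128 threads and a 64 × 32 tile of halves, for instance, each thread issues two 16-byte copies, and the 256 chunks are each covered once.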

Because the load is non-blocking, a separate memory fence is needed to synchronize the threads. As stated in the PTX manual [PTX-2], asynchronous copies need to be committed to a group and waited for. The following two macros, comprising PTX instructions, do exactly that:

  CP_ASYNC_COMMIT_GROUP();
  CP_ASYNC_WAIT_GROUP(0);

These two macros get compiled into the following SASS code:

  LDGDEPBAR
  DEPBAR.LE SB0, 0x0

These two SASS instructions are found in the cuBLAS code too. The slight difference between the two versions (the argument 0x0 here versus 0x1 in cuBLAS) is covered under Improvement II. Visualizing the new load instruction LDGSTS.E.BYPASS.LTC128B.128 is very instructive (see Figure 4): the data goes directly from the L2 cache into shared memory (a portion of the L1 cache), bypassing the registers.

Figure 4

Overlapping memory loads with computation

The gift of asynchronous copy operations is that one can overlap computation with memory transfers and avoid pipeline stalls. The kernel can be expressed as shown in Listing 5.

size_t counter = 0;
LoaderA.load(counter);
LoaderB.load(counter);
LoaderA.next(Threadblock::kK);
LoaderB.next(Threadblock::kK * N);
CP_ASYNC_COMMIT_GROUP();
CP_ASYNC_WAIT_GROUP(0);
__syncthreads();
for (size_t block = 0;
  block < K - Threadblock::kK;
  block += Threadblock::kK) {
  LoaderA.load(counter ^ 1);
  LoaderB.load(counter ^ 1);
  LoaderA.next(Threadblock::kK);
  LoaderB.next(Threadblock::kK * N);
  constexpr size_t wmma_steps = 
    Threadblock::kK / WMMAblock::kK;
  for (size_t wmma_step = 0; 
      wmma_step < wmma_steps; ++wmma_step) {
    RegisterLoaderA.load();
    RegisterLoaderB.load();
    RegisterLoaderA.step(WMMAblock::kK);
    RegisterLoaderB.step
      (Bs.cols_ * WMMAblock::kN);
    matmul.compute();
  }
  counter ^= 1;
  RegisterLoaderA.reset(counter *
    Threadblock::kM * (Threadblock::kK + skew));
  RegisterLoaderB.reset(counter * 
    Threadblock::kK * (Threadblock::kN + skew));
  CP_ASYNC_COMMIT_GROUP();
  CP_ASYNC_WAIT_GROUP(0);
  __syncthreads();
}
for (size_t bk = 0; bk < Threadblock::kK;
     bk += WMMAblock::kK) {
  RegisterLoaderA.load();
  RegisterLoaderB.load();
  RegisterLoaderA.step(WMMAblock::kK);
  RegisterLoaderB.step(Bs.cols_ * WMMAblock::kN);
  matmul.compute();
}
Listing 5

The computation starts by loading data from global to shared memory. The class that loads data from global to shared memory manages two buffers: while data is read from one buffer, the next block is stored in the other. Each iteration of the main loop begins by initiating a global memory load into the idle buffer. The matrix products are then computed on the current buffer. Afterward, the threads block until the load issued at the start of the iteration has completed. After the loop, an epilogue conducts the last outstanding matrix computation.

This kernel attains 73 TFLOPS, a 20 percent increase over the first kernel.
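The ping-pong scheme of Listing 5 can be sketched on the host by reducing the matrix product to a dot product. In the sketch below, pingpong_dot, load_tile and compute_tile are hypothetical names; two buffers stand in for the two halves of shared memory and alternate via counter ^ 1, as in the listing. The copies here are synchronous, so the sketch shows the buffer bookkeeping (prologue, main loop, epilogue) rather than any actual overlap.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

constexpr std::size_t kTile = 4;  // hypothetical tile length

// Two "shared memory" buffers for each operand.
struct Buffers {
  std::array<std::array<float, kTile>, 2> x, y;
};

// Stands in for LoaderA/LoaderB.load(buf): copy tile `tile` into `buf`.
void load_tile(Buffers& s, std::size_t buf, const std::vector<float>& x,
               const std::vector<float>& y, std::size_t tile) {
  for (std::size_t i = 0; i < kTile; ++i) {
    s.x[buf][i] = x[tile * kTile + i];
    s.y[buf][i] = y[tile * kTile + i];
  }
}

// Stands in for the register loads and matmul.compute() on buffer `buf`.
float compute_tile(const Buffers& s, std::size_t buf) {
  float acc = 0.0f;
  for (std::size_t i = 0; i < kTile; ++i) acc += s.x[buf][i] * s.y[buf][i];
  return acc;
}

// Assumes x.size() == y.size() and that it is a nonzero multiple of kTile.
float pingpong_dot(const std::vector<float>& x, const std::vector<float>& y) {
  const std::size_t tiles = x.size() / kTile;
  Buffers s{};
  std::size_t counter = 0;
  load_tile(s, counter, x, y, 0);  // prologue: first tile into buffer 0
  float acc = 0.0f;
  for (std::size_t t = 1; t < tiles; ++t) {
    load_tile(s, counter ^ 1, x, y, t);  // GPU: asynchronous copy
    acc += compute_tile(s, counter);     // GPU: overlaps with that copy
    counter ^= 1;                        // GPU: after wait + __syncthreads()
  }
  acc += compute_tile(s, counter);  // epilogue: last tile
  return acc;
}
```

The invariant to notice is that the buffer being written (counter ^ 1) is never the buffer being read (counter) within one iteration; only the wait and the barrier at the end of the iteration make the freshly loaded buffer safe to read.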

Improvement II: double buffering

The code above already improves the throughput of the kernel. However, it is still below the cuBLAS version, and the assembly instructions do not match. In particular, the memory barrier in the code above is DEPBAR.LE SB0, 0x0, but the memory barrier in the cuBLAS code is DEPBAR.LE SB0, 0x1. The SASS instructions are undocumented, but one can assume that LE stands for less or equal. Furthermore, the PTX documentation for the barrier [PTX-3] describes cp.async.wait_group N as follows:

cp.async.wait_group instruction will cause the executing thread to wait till only N or fewer of the most recent cp.async-groups are pending and all the prior cp.async-groups committed by the executing threads are complete.
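To make the quoted semantics concrete, here is a toy host-side model of the group bookkeeping of one thread. The class AsyncGroups and its methods are hypothetical names; the model moves no data and tracks only which groups are guaranteed to be complete after a wait, since real copies may finish earlier.

```cpp
#include <cassert>
#include <cstddef>
#include <deque>
#include <vector>

// Toy model of cp.async groups for a single thread (not a GPU API).
class AsyncGroups {
 public:
  // cp.async: add a copy into `buffer` to the open, uncommitted group.
  void copy(int buffer) { open_.push_back(buffer); }

  // cp.async.commit_group: close the open group; it is now pending.
  void commit_group() {
    pending_.push_back(open_);
    open_.clear();
  }

  // cp.async.wait_group n: return once at most the n most recent groups
  // are pending; every older group is then complete.
  void wait_group(std::size_t n) {
    while (pending_.size() > n) {
      for (int b : pending_.front()) complete_.push_back(b);
      pending_.pop_front();
    }
  }

  bool guaranteed_complete(int buffer) const {
    for (int b : complete_)
      if (b == buffer) return true;
    return false;
  }

  std::size_t pending() const { return pending_.size(); }

 private:
  std::vector<int> open_;
  std::deque<std::vector<int>> pending_;
  std::vector<int> complete_;
};
```

With wait_group(0), as in Listing 5, nothing is left in flight after the barrier. With wait_group(1), as in the cuBLAS SASS, the oldest group is guaranteed to have landed while the most recent one can keep loading during the next computation.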

Besides the difference in instructions, the kernel above also stalled regularly because data was unavailable: for each issued instruction, warps stalled for almost two cycles waiting for data (long scoreboard stalls). To avoid such stalls and replicate the SASS code of the cuBLAS kernel, the kernel below does “double buffering”: it always keeps two global-to-shared memory copies in flight and waits only for the oldest one. Register loads are buffered too: the kernel keeps one register fragment loaded, loads the next one, and computes the matrix product on the one loaded previously. The code for the kernel is in Listing 6.

LoaderA.load(0);
LoaderB.load(0);
LoaderA.next(Threadblock::kK);
LoaderB.next(Threadblock::kK * N);
CP_ASYNC_COMMIT_GROUP();
LoaderA.load(1);
LoaderB.load(1);
LoaderA.next(Threadblock::kK);
LoaderB.next(Threadblock::kK * N);
CP_ASYNC_COMMIT_GROUP();
CP_ASYNC_WAIT_GROUP(1); // wait until at most 1
         // (the most recent) group is pending
__syncthreads();
RegisterLoaderA.load(0);
RegisterLoaderB.load(0);
RegisterLoaderA.step(WMMAblock::kK);
RegisterLoaderB.step
  (SpanTypeB::cols_ * WMMAblock::kN);
size_t counter = 1;
for (size_t block = 0; block < K - 2 * Threadblock::kK;
  block += Threadblock::kK) {
    constexpr size_t wmma_steps =
      Threadblock::kK / WMMAblock::kK;
    for (size_t i = 0; i < wmma_steps; ++i) {
      size_t current = i % 2;
      size_t next = (i + 1) % 2;
      RegisterLoaderA.load(next);
      RegisterLoaderB.load(next);
      RegisterLoaderA.step(WMMAblock::kK);
      RegisterLoaderB.step
        (SpanTypeB::cols_ * WMMAblock::kN);
      matmul.compute(current);
      if (i == 0) {
        LoaderA.load(counter ^ 1);
        LoaderB.load(counter ^ 1);
        LoaderA.next(Threadblock::kK);
        LoaderB.next(Threadblock::kK * N);
        CP_ASYNC_COMMIT_GROUP();
        CP_ASYNC_WAIT_GROUP(1);
        __syncthreads();
        RegisterLoaderA.reset
          (counter * MemLoaderA::size_);
        RegisterLoaderB.reset
          (counter * MemLoaderB::size_);
        counter ^= 1;
      }
    }
  __syncthreads();
}
Listing 6

The prologue to the main loop begins by issuing two global-to-shared memory copies. The threads block until the first copy is completed, while the second one remains in flight. Then, the first register fragment is loaded. Each step of the main loop loads a further fragment from shared memory into registers while the tensor cores operate on the previous fragment. When all fragments of a block have been loaded into registers, the shared memory of that block is exhausted, and no more computation can be overlapped with the memory copies. Another copy is issued, and the warps wait until the previous copy has completed.
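The register-level half of this scheme can be sketched on the host as well. This is a simplified model under stated assumptions: a hypothetical Fragment holds two floats whose product stands in for a tensor core operation, and plain assignments stand in for loads from shared memory. The point is only the index pattern current = i % 2, next = (i + 1) % 2 used in Listing 6.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for an A/B register fragment pair.
struct Fragment { float a, b; };

// Register double buffering: while fragment i (in reg[i % 2]) is being
// consumed, fragment i + 1 is already loaded into reg[(i + 1) % 2].
float double_buffered_sum(const std::vector<Fragment>& tile) {
  if (tile.empty()) return 0.0f;
  std::array<Fragment, 2> reg{};
  reg[0] = tile[0];  // prologue: first fragment into registers
  float acc = 0.0f;
  for (std::size_t i = 0; i < tile.size(); ++i) {
    const std::size_t current = i % 2;
    const std::size_t next = (i + 1) % 2;
    if (i + 1 < tile.size()) reg[next] = tile[i + 1];  // prefetch
    acc += reg[current].a * reg[current].b;  // "matmul.compute(current)"
  }
  return acc;
}
```

On the GPU, the benefit comes from the compute step not depending on the load issued just before it, so the load latency is hidden behind the tensor core work; the host sketch only demonstrates that the alternation consumes every fragment exactly once and never overwrites the fragment in use.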

With these advances, the throughput of the kernel rises to 89 TFLOPS, more than 95% of cuBLAS performance. Further gains can be reaped by writing the result of the multiplication back to global memory through shared memory. The kernel throughput then advances to 91 TFLOPS, 1 TFLOPS behind the cuBLAS kernel.
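The TFLOPS figures quoted in this article are throughput numbers. A common convention, assumed here since the article does not spell out its measurement, counts 2·M·N·K floating point operations for the product of an M×K and a K×N matrix (one multiply and one add per inner-product term) and divides by the kernel's run time. A minimal helper, with hypothetical names:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// GEMM throughput in TFLOPS, assuming 2 * M * N * K operations.
double gemm_tflops(std::uint64_t m, std::uint64_t n, std::uint64_t k,
                   double seconds) {
  const double flops = 2.0 * double(m) * double(n) * double(k);
  return flops / seconds / 1e12;
}

// Throughput of a kernel relative to a reference (e.g. cuBLAS).
double relative_to(double tflops, double reference_tflops) {
  return tflops / reference_tflops;
}
```

For instance, relative_to(73.0, 60.0) ≈ 1.22 reproduces the roughly 20 percent gain of the first improvement.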

References

[ACM21] ACM Turing Award 2021: available at https://awards.acm.org/about/2021-turing

[Perkel21] Jeffrey M. Perkel ‘Ten computer codes that transformed science’, published on Nature website 20 January 2021 (last updated 8 April 2021), available at https://www.nature.com/articles/d41586-021-00075-2.

[PTX-1] PTX Version 7.0 documentation, ‘Changes in PTX ISA Version 7.0’, published by NVIDIA, available at https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#changes-in-ptx-isa-version-7-0

[PTX-2] PTX Version 7.0 documentation, ‘Data Movement and Conversion Instructions: Asynchronous copy’, published by NVIDIA, available at https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=async#data-movement-and-conversion-instructions-asynchronous-copy

[PTX-3] PTX Version 7.0 documentation, ‘Data Movement and Conversion Instructions: cp.async.wait_group/cp.async.wait_all’, published by NVIDIA, available at https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=async#data-movement-and-conversion-instructions-cp-async-wait-group-cp-async-wait-all

This article was previously published on GitHub by Fabian on 14 March 2024 and is available at https://fabianschuetze.github.io/category/articles.html

Fabian Schuetze Fabian works on computer vision and AI in the automotive and robotics industry. When not working, he’s enjoying running or drinking wine, though not at the same time.