如何优化gemm算子

最近在做开源之下的项目,记录一下学习 gemm 的优化过程。

gemm 即矩阵乘法,通常的 gemm 可表示为一个 MxK 的矩阵 A 和一个 KxN 的矩阵 B 相乘得到一个 MxN 的矩阵 C。矩阵计算的优化思路可以总结为以下几步,在这个博客中省去了对计算机组成原理的一些介绍。

# 优化思路

  1. Split MNK

    这一步优化主要是基于计算机体系架构的 cache 结构,节省计算数据在 cache 和内存间读取的时间
    切分 MNK(根据 L1 cache 的大小),比如 arm64 中 FP32M8N12 就是将矩阵 A(MxK)切分为若干个(8xK)矩阵,将矩阵 B(KxN)切分为若干个(Kx12)的矩阵,在第 2 步中提到了为何这么切分。

  2. 分层计算

    这一步优化主要是节省中间数据在内存和 cache 之间换入换出的时间。
    传统的 MNK 矩阵切分计算后,应该是 A 的小矩阵的每一乘上 B 的小矩阵的每一,这种做法比较低效,且存在很多冗余的重复加载。
    分层计算的核心是每次用 A 的小矩阵的每一(8 个元素,1X8 矩阵)和 B 的小矩阵的每一(12 个元素,12X1 矩阵)相乘,得到中间结果矩阵(一层 8x12 矩阵),在这个过程中,把每一个中间结果矩阵累加,就得到了小矩阵 8K 和小矩阵 K12 计算的结果(完整 8x12 矩阵)。
    以 arm64 架构的 CPU 为例子,有 32 个 128 位的 neon 向量寄存器,那么 M8N12 的每次分层计算中,A 需要从内存中读取 8 个元素,占用 2 个寄存器;B 需要从内存中读取 12 个元素,占用 3 个寄存器;C(中间结果)的形状为 8x12,需要占用 24 个寄存器来累加运算。

  3. SIMD 优化

    这一步优化主要是通过并行计算提高数据运算的速度。在 cache 计算中使用 SIMD 优化,并行计算元素与元素相乘的过程。

  4. 内存重排(pack)

    这一步的优化主要是节省数据在内存中读取的时间
    以 x86 的 CPU 为例,矩阵在内存中是逐行存储的,一个 4x4 的矩阵 A 的存储顺序(由小到大)为

    | 0 | 1 | 2 | 3 |
    | 4 | 5 | 6 | 7 |
    | 8 | 9 | A | B |
    | C | D | E | F |

    假设该矩阵被分割成 2 个 2x2 的矩阵,在计算中需要逐列读取,那么第一个被读取的元素是 0 处的,第二个被读取的元素是 4 处的,他们的地址并不连续,这会浪费一定的寻址时间。因此,在 2x2 的分割下,需要针对该矩阵进行内存重排,重排后的顺序如下,这种重排叫做 zigzag,因为是按照之字型排序的。

    | 0 | 2 | 4 | 6 |
    | 1 | 3 | 5 | 7 |
    | 8 | A | C | E |
    | 9 | B | D | F |

# Example

以 Megcc 生成的算子 Arm64_fp32_m8_n12_mk4_matmul_NO_BIAS 为例子,分析一下如何实现 gemm 优化过后的算子。这是按照 M8N12K4 的分割的,想象一个三维的长方体分割就可以。

  1. 重排内存

    kernel 的执行入口是 void Arm64_fp32_m8_n12_mk4_matmul_NO_BIAS ,该程序首先确认矩阵的内存空间。

    • Code

      1
      2
      3
      4
      5
      6
      7
      size_t pack_a_size = Arm64_fp32_m8_n12_mk4_matmul_NO_BIAS_workspace_a(0, M, 0, K);
      float* pack_a = workspace; # 矩阵A的开始地址
      float* pack_b = workspace + pack_a_size; #矩阵B的开始地址
      Arm64_fp32_m8_n12_mk4_matmul_NO_BIAS_packa_n(pack_a, A, LDA, 0, M, 0, K);
      # 按照上述第4步对矩阵进行重排
      Arm64_fp32_m8_n12_mk4_matmul_NO_BIAS_packb_n(pack_b, B, LDB, 0, N, 0, K);
      # 按照上述第4步对矩阵进行重排
  2. 矩阵计算

    矩阵的计算是逐 Block 进行的, M8N12K4 的计算方式中,每一个块的大小为 8x12x4。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    for (; m + m_block <= M; m += m_block) {
    float* output = C + (m / pack_mk * LDC);

    size_t n = 0;
    const float* cur_pack_b = pack_b;
    for (; n + n_block <= N; n += n_block) {
    kern_8x12_bias_relu(pack_a, cur_pack_b, K, output, LDC, _bias_ptr); # 每一个小block的计算
    output += n_block * pack_mk;
    ;
    cur_pack_b += K12;
    }
    # 当n不能被12整除时,剩下的n按照4分割计算。
    for (; n < N; n += 4) {
    kern_8x4_bias_relu(
    pack_a, cur_pack_b, K, output, LDC, _bias_ptr,
    N - n > 4 ? 4 : N - n);
    # 即使剩下的不能被4整除也按照4向上对齐处理。
    output += 4 * pack_mk;
    ;
    cur_pack_b += K4;
    }
    pack_a += K8;
    }
    #当m不能被8整除时,剩下的m的维度按照4整除计算。
    for (; m < M; m += m_block_4) {
    float* output = C + (m / pack_mk * LDC);

    size_t n = 0;
    const float* cur_pack_b = pack_b;
    for (; n + n_block - 1 < N; n += n_block) {
    kern_4x12_bias_relu(pack_a, cur_pack_b, K, output, LDC, _bias_ptr);
    output += n_block * pack_mk;
    ;
    cur_pack_b += K12;
    }
    #同上
    for (; n < N; n += 4) {
    kern_4x4_bias_relu(
    pack_a, cur_pack_b, K, output, LDC, _bias_ptr,
    N - n > 4 ? 4 : N - n);
    output += 4 * pack_mk;
    ;
    cur_pack_b += K4;
    }
    pack_a += K4;
    }
  3. 汇编实现:内存重排

    对 A 进行内存重排的程序入口如下,矩阵 A 的形状为 [M/4, K/4, 4, 4],如下图所示,A 矩阵在内存中首先以块顺序 0-7 存储,在每一块内按照 0-F 的顺序存储。对 A 进行内存重排后的形状为 [M/8, K, 8],在图中为了方便理解,箭头方向为数据存储的方向,假设 M=16,K=8。

    以下程序描述了如何进行内存重排的过程。首先需要两个指针 inptr0inptr1 定位原矩阵的内存,其中 inptr0 定位的是第 0 矩阵块(以下都称为块,描述的是上图中矩阵的分块情况)内存起点, inptr1 定位的是第 2 块内存起点, prefetch_2x 函数的作用是读取 32 个 fp32 数据,即 2 块,该函数将 0、1 和 2、3 块数据加载到缓存中, interleave_2x4_4_s 是重排函数,作用是将载入缓存的块数据重排,具体如何展开请看后续。注意,若 m 不被 8 整除,需要在下一个循环代码中处理尾部数据。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    void Arm64_fp32_m8_n12_mk4_matmul_NO_BIAS_packa_n(
    float* outptr, const float* inptr, int ldin, int y0, int ymax, int k0,
    int kmax) {
    const int pack_mk = 4;
    const int pack_m = 8;
    const int m_stride = pack_m * pack_mk;
    const int min_m_stride = pack_mk * pack_mk;
    int y = 0;
    for (; y + 7 < ymax; y += pack_m) {
    const float* inptr0 = inptr + y / pack_mk * ldin;
    const float* inptr1 = inptr0 + ldin;
    prefetch_2x(inptr0);
    prefetch_2x(inptr1);
    int k = (kmax);
    for (; k > 3; k -= pack_mk) {
    interleave_2x4_4_s(inptr0, inptr1, outptr);
    outptr += m_stride;
    inptr0 += min_m_stride;
    inptr1 += min_m_stride;
    }
    }
    # m不被8整除
    for (; y < ymax; y += pack_mk) {
    const float* inptr0 = inptr + y / pack_mk * ldin;
    prefetch_2x(inptr0);
    int K = (kmax);
    for (; K > 3; K -= pack_mk) {
    interleave_1x4_4_s(inptr0, outptr);
    outptr += min_m_stride;
    inptr0 += min_m_stride;
    }
    }
    }

    重排函数实现的重点是函数 interleave_2x4_4_s ,下图描述了矩阵 A 内存重排的实现过程。其中 v0-v7 为寄存器。程序首先将数据加载到寄存器后,重排顺序后再输出到矩阵 A 的输出位置指针。

    对 B 实现内存重排的原理类似,其原始形状为 [K/4, N, 4],重排后形状为 [N/12, K, 12],其形状变化和内存排序如下图所示,和 A 相比,需要多进行一个类似转置的操作。如下图所示,该重排把一个 4x12 的矩阵块(第一块绿色,紫色,黄色的三块)重排成了一个 12x4 的矩阵块。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    void Arm64_fp32_m8_n12_mk4_matmul_NO_BIAS_packb_n(
    float* outptr, const float* inptr, int ldin, int x0, int xmax, int k0,
    int kmax) {
    float tmpbuff[16] = {0.0f};

    int PACK_C_SIZE = 4;
    int ksize = kmax - k0;
    int ksize12 = ksize * 12;
    int ksize4 = (ksize << 2);
    float* outptr_base = outptr;
    float* outptr_base4 = outptr_base + (xmax - x0) / 12 * ksize12;

    int k = k0;
    for (; k + 3 < kmax; k += 4) {
    const float* temp_inptr = inptr + k / PACK_C_SIZE * ldin + x0 * PACK_C_SIZE;
    prefetch_3x(temp_inptr);

    int x = x0;
    float* temp_outptr = outptr_base;
    for (; x + 12 <= xmax; x += 12) {
    float* outptr_interleave = temp_outptr;
    transpose_1x12_4_s(temp_inptr, outptr_interleave);
    temp_outptr += ksize12;
    temp_inptr += 4 * 12;
    }
    temp_outptr = outptr_base4;
    for (; x + 4 <= xmax; x += 4) {
    float* outptr_interleave = temp_outptr;
    asm volatile(
    "ld4 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n"
    "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[outptr0]]\n"
    : [inptr0] "+r"(temp_inptr), [outptr0] "+r"(outptr_interleave)
    :
    : "v0", "v1", "v2", "v3", "memory");
    temp_outptr += ksize4;
    }
    if (x < xmax) {
    memcpy(tmpbuff, temp_inptr, sizeof(float) * (xmax - x) * PACK_C_SIZE);
    float* outptr_interleave = temp_outptr;
    const float* tmp_ptr = &tmpbuff[0];
    asm volatile(
    "ld4 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n"
    "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[outptr0]]\n"
    : [inptr0] "+r"(tmp_ptr), [outptr0] "+r"(outptr_interleave)
    :
    : "v0", "v1", "v2", "v3", "memory");
    temp_outptr += ksize4;
    }
    outptr_base += 12 * 4;
    outptr_base4 += 4 * 4;
    }
    }
  4. 汇编实现: kernel func

    根据实现细节,实现一个 M8N12K4 的算子需要以下几个分层运算函数: kern_8x12_bias_relu , kern_8x4_bias_relu , kern_4x12_bias_relu , kern_4x4_bias_relu , 这些都是基于汇编语言实现的。以主要的 kern_8x12_bias_relu 为例子,来看一下这个函数是如何实现分层计算的。
    需要理解的汇编语言操作如下表

    1. eor 异或操作
    2. ld1 从内存中加载数据
    3. prfm 预取内存到缓存
    4. cmp 比较操作符
    5. fmla 向量乘加操作
    6. bne 比较结果不相等时跳转
    7. st1 将寄存器的数据写入内存

    根据优化思路的第 2 步,首先你需要加载 A 的 8 个元素(2 个寄存器)和 B 的 12 个元素(3 个寄存器)到 cache 中,再为中间结果矩阵准备 24 个寄存器。
    在以下这段汇编代码中,首先通过异或操作 eor 清空了 8-31 寄存器的所有字节(x.16b),一共 24 个,这些寄存器是为了存储中间结果准备的。用 output0 和 output1 两个指针作为输出,并通过 prfm 把他们预取到缓存中。以下是汇编代码的主体实现,实际使用时,arm 架构有 32 个 128 位的寄存器,除去以上的 29 个,另外空闲的 3 个处理器分给 B 做数据的预取。

    • Code

      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      16
      17
      18
      19
      20
      21
      22
      23
      24
      25
      26
      27
      28
      29
      30
      31
      32
      33
      34
      35
      36
      37
      38
      39
      40
      41
      42
      43
      44
      45
      46
      47
      48
      49
      50
      51
      52
      53
      54
      55
      56
      57
      58
      59
      60
      61
      62
      63
      64
      65
      66
      67
      68
      69
      70
      71
      72
      73
      74
      75
      76
      77
      78
      79
      80
      81
      82
      83
      84
      85
      86
      87
      88
      89
      90
      91
      92
      93
      94
      95
      96
      97
      98
      99
      100
      101
      "eor  v8.16b, v8.16b, v8.16b     \n" # 清空v8寄存器,下同
      "eor v9.16b, v9.16b, v9.16b \n"
      "eor v10.16b, v10.16b, v10.16b \n"
      "prfm pstl1keep, [%[output0]] \n" # 预取
      "eor v11.16b, v11.16b, v11.16b \n"
      "eor v12.16b, v12.16b, v12.16b \n"
      "eor v13.16b, v13.16b, v13.16b \n"
      "prfm pstl1keep, [%[output1]] \n"
      "eor v14.16b, v14.16b, v14.16b \n"
      "eor v15.16b, v15.16b, v15.16b \n"
      "ld1 {v2.4s}, [%[b_ptr]], #16 \n" # v2加载B的四个字节,同时将指针向后移16个字节
      "eor v16.16b, v16.16b, v16.16b \n"
      "eor v17.16b, v17.16b, v17.16b \n"
      "eor v18.16b, v18.16b, v18.16b \n"
      "eor v19.16b, v19.16b, v19.16b \n"
      "eor v20.16b, v20.16b, v20.16b \n"
      "ld1 {v3.4s}, [%[b_ptr]], #16 \n" # v3加载B的四个字节,同时将指针向后移16个字节
      "eor v21.16b, v21.16b, v21.16b \n"
      "eor v22.16b, v22.16b, v22.16b \n"
      "eor v23.16b, v23.16b, v23.16b \n"
      "ld1 {v4.4s}, [%[b_ptr]], #16 \n" # v4加载B的四个字节,同时将指针向后移16个字节
      "eor v24.16b, v24.16b, v24.16b \n"
      "eor v25.16b, v25.16b, v25.16b \n"
      "eor v26.16b, v26.16b, v26.16b \n"
      "eor v27.16b, v27.16b, v27.16b \n"
      "eor v28.16b, v28.16b, v28.16b \n"
      "ld1 {v0.4s}, [%[a_ptr]], #16 \n"# v0加载A的四个字节,同时将指针向后移16个字节
      "eor v29.16b, v29.16b, v29.16b \n"
      "eor v30.16b, v30.16b, v30.16b \n"
      "eor v31.16b, v31.16b, v31.16b \n"
      "2: \n"
      "cmp %w[K], #0\n" #如果k=0 此时k=1
      "beq 4f\n"

      "3:\n"
      "fmla v8.4s, v0.4s, v2.s[0]\n" # A的前四个元素和B的第一个元素相乘
      "fmla v9.4s, v0.4s, v2.s[1]\n" # A的前四个元素和B的第二个元素相乘
      "ld1 {v1.4s}, [%[a_ptr]], 16\n" # v1加载A的四个字节,同时指针偏移
      "fmla v10.4s, v0.4s, v2.s[2]\n" # A的前四个元素和B的第三个元素相乘
      "fmla v11.4s, v0.4s, v2.s[3]\n" # A的前四个元素和B的第四个元素相乘
      "ld1 {v5.4s}, [%[b_ptr]], #16\n" # v5加载B的四个字节,同时指针偏移
      "fmla v12.4s, v0.4s, v3.s[0]\n"
      "fmla v13.4s, v0.4s, v3.s[1]\n"
      "ld1 {v6.4s}, [%[b_ptr]], #16\n" # v6加载B的四个字节,同时指针偏移
      "fmla v14.4s, v0.4s, v3.s[2]\n"
      "fmla v15.4s, v0.4s, v3.s[3]\n"
      "ld1 {v7.4s}, [%[b_ptr]], #16\n" # v7加载B的四个字节,同时指针偏移
      "fmla v16.4s, v0.4s, v4.s[0]\n"
      "fmla v17.4s, v0.4s, v4.s[1]\n"
      "fmla v18.4s, v0.4s, v4.s[2]\n"
      "fmla v19.4s, v0.4s, v4.s[3]\n"
      "ld1 {v0.4s}, [%[a_ptr]], 16\n" # 此时A的前四个元素的计算完成,向后加载4个字节

      "fmla v20.4s, v1.4s, v2.s[0]\n"
      "fmla v21.4s, v1.4s, v2.s[1]\n"
      "fmla v22.4s, v1.4s, v2.s[2]\n"
      "fmla v23.4s, v1.4s, v2.s[3]\n"
      "fmla v24.4s, v1.4s, v3.s[0]\n"
      "fmla v25.4s, v1.4s, v3.s[1]\n"
      "fmla v26.4s, v1.4s, v3.s[2]\n"
      "fmla v27.4s, v1.4s, v3.s[3]\n"
      "fmla v28.4s, v1.4s, v4.s[0]\n"
      "fmla v29.4s, v1.4s, v4.s[1]\n"
      "fmla v30.4s, v1.4s, v4.s[2]\n"
      "fmla v31.4s, v1.4s, v4.s[3]\n"

      "fmla v8.4s, v0.4s, v5.s[0]\n"
      "fmla v9.4s, v0.4s, v5.s[1]\n"
      "ld1 {v1.4s}, [%[a_ptr]], 16\n" # 此时A的第一个8x1的矩阵计算完毕,向后加载4个字节
      "fmla v10.4s, v0.4s, v5.s[2]\n"
      "fmla v11.4s, v0.4s, v5.s[3]\n"
      "ld1 {v2.4s}, [%[b_ptr]], 16\n"
      "fmla v12.4s, v0.4s, v6.s[0]\n"
      "fmla v13.4s, v0.4s, v6.s[1]\n"
      "ld1 {v3.4s}, [%[b_ptr]], 16\n"
      "fmla v14.4s, v0.4s, v6.s[2]\n"
      "fmla v15.4s, v0.4s, v6.s[3]\n"
      "ld1 {v4.4s}, [%[b_ptr]], 16\n" # B也向后加载一个1x12的矩阵分块
      "fmla v16.4s, v0.4s, v7.s[0]\n"
      "fmla v17.4s, v0.4s, v7.s[1]\n"
      "fmla v18.4s, v0.4s, v7.s[2]\n"
      "fmla v19.4s, v0.4s, v7.s[3]\n"
      "ld1 {v0.4s}, [%[a_ptr]], 16\n" # 同上

      "fmla v20.4s, v1.4s, v5.s[0]\n"
      "fmla v21.4s, v1.4s, v5.s[1]\n"
      "fmla v22.4s, v1.4s, v5.s[2]\n"
      "fmla v23.4s, v1.4s, v5.s[3]\n"
      "fmla v24.4s, v1.4s, v6.s[0]\n"
      "subs %w[K], %w[K], #1\n" #处理K的大小
      "fmla v25.4s, v1.4s, v6.s[1]\n"
      "fmla v26.4s, v1.4s, v6.s[2]\n"
      "fmla v27.4s, v1.4s, v6.s[3]\n"
      "fmla v28.4s, v1.4s, v7.s[0]\n"
      "fmla v29.4s, v1.4s, v7.s[1]\n"
      "fmla v30.4s, v1.4s, v7.s[2]\n"
      "fmla v31.4s, v1.4s, v7.s[3]\n"

      "bne 3b\n"

      # 后面是一些处理尾部的代码块,此处略
作者

lcy

发布于

2024-07-19

更新于

2024-07-19

许可协议