如何优化gemm算子
最近在做开源之下的项目,记录一下学习 gemm 的优化过程。
gemm 即矩阵乘法,通常的 gemm 可表示为一个 MxK 的矩阵 A 和一个 KxN 的矩阵 B 相乘得到一个 MxN 的矩阵 C。矩阵计算的优化思路可以总结为以下几步,在这个博客中省去了对计算机组成原理的一些介绍。
# 优化思路
-
Split MNK
这一步优化主要是基于计算机体系架构的 cache 结构,节省计算数据在 cache 和内存间读取的时间。
切分 MNK(根据 L1 cache 的大小),比如 arm64 中 FP32M8N12 就是将矩阵 A(MxK)切分为若干个(8xK)矩阵,将矩阵 B(KxN)切分为若干个(Kx12)的矩阵,在第 2 步中提到了为何这么切分。 -
分层计算
这一步优化主要是节省中间数据在内存和 cache 之间换入换出的时间。
传统的 MNK 矩阵切分计算后,应该是 A 的小矩阵的每一行乘上 B 的小矩阵的每一列,这种做法比较低效,且存在很多冗余的重复加载。
分层计算的核心是每次用 A 的小矩阵的每一列(8 个元素,1X8 矩阵)和 B 的小矩阵的每一行(12 个元素,12X1 矩阵)相乘,得到中间结果矩阵(一层 8x12 矩阵),在这个过程中,把每一个中间结果矩阵累加,就得到了小矩阵 8K 和小矩阵 K12 计算的结果(完整 8x12 矩阵)。
以 arm64 架构的 CPU 为例子,有 32 个 128 位的 neon 向量寄存器,那么 M8N12 的每次分层计算中,A 需要从内存中读取 8 个元素,占用 2 个寄存器;B 需要从内存中读取 12 个元素,占用 3 个寄存器;C(中间结果)的形状为 8x12,需要占用 24 个寄存器来累加运算。 -
SIMD 优化
这一步优化主要是通过并行计算提高数据运算的速度。在 cache 计算中使用 SIMD 优化,并行计算元素与元素相乘的过程。
-
内存重排(pack)
这一步的优化主要是节省数据在内存中读取的时间。
以 x86 的 CPU 为例,矩阵在内存中是逐行存储的,一个 4x4 的矩阵 A 的存储顺序(由小到大)为| 0 | 1 | 2 | 3 |
| 4 | 5 | 6 | 7 |
| 8 | 9 | A | B |
| C | D | E | F |假设该矩阵被分割成 2 个 2x2 的矩阵,在计算中需要逐列读取,那么第一个被读取的元素是 0 处的,第二个被读取的元素是 4 处的,他们的地址并不连续,这会浪费一定的寻址时间。因此,在 2x2 的分割下,需要针对该矩阵进行内存重排,重排后的顺序如下,这种重排叫做 zigzag,因为是按照之字型排序的。
| 0 | 2 | 4 | 6 |
| 1 | 3 | 5 | 7 |
| 8 | A | C | E |
| 9 | B | D | F |
# Example
以 Megcc 生成的算子 Arm64_fp32_m8_n12_mk4_matmul_NO_BIAS
为例子,分析一下如何实现 gemm 优化过后的算子。这是按照 M8N12K4 的分割的,想象一个三维的长方体分割就可以。
-
重排内存
kernel 的执行入口是
void Arm64_fp32_m8_n12_mk4_matmul_NO_BIAS
,该程序首先确认矩阵的内存空间。-
Code
1
2
3
4
5
6
7size_t pack_a_size = Arm64_fp32_m8_n12_mk4_matmul_NO_BIAS_workspace_a(0, M, 0, K);
float* pack_a = workspace; # 矩阵A的开始地址
float* pack_b = workspace + pack_a_size; #矩阵B的开始地址
Arm64_fp32_m8_n12_mk4_matmul_NO_BIAS_packa_n(pack_a, A, LDA, 0, M, 0, K);
# 按照上述第4步对矩阵进行重排
Arm64_fp32_m8_n12_mk4_matmul_NO_BIAS_packb_n(pack_b, B, LDB, 0, N, 0, K);
# 按照上述第4步对矩阵进行重排
-
-
矩阵计算
矩阵的计算是逐 Block 进行的, M8N12K4 的计算方式中,每一个块的大小为 8x12x4。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46for (; m + m_block <= M; m += m_block) {
float* output = C + (m / pack_mk * LDC);
size_t n = 0;
const float* cur_pack_b = pack_b;
for (; n + n_block <= N; n += n_block) {
kern_8x12_bias_relu(pack_a, cur_pack_b, K, output, LDC, _bias_ptr); # 每一个小block的计算
output += n_block * pack_mk;
;
cur_pack_b += K12;
}
# 当n不能被12整除时,剩下的n按照4分割计算。
for (; n < N; n += 4) {
kern_8x4_bias_relu(
pack_a, cur_pack_b, K, output, LDC, _bias_ptr,
N - n > 4 ? 4 : N - n);
# 即使剩下的不能被4整除也按照4向上对齐处理。
output += 4 * pack_mk;
;
cur_pack_b += K4;
}
pack_a += K8;
}
#当m不能被8整除时,剩下的m的维度按照4整除计算。
for (; m < M; m += m_block_4) {
float* output = C + (m / pack_mk * LDC);
size_t n = 0;
const float* cur_pack_b = pack_b;
for (; n + n_block - 1 < N; n += n_block) {
kern_4x12_bias_relu(pack_a, cur_pack_b, K, output, LDC, _bias_ptr);
output += n_block * pack_mk;
;
cur_pack_b += K12;
}
#同上
for (; n < N; n += 4) {
kern_4x4_bias_relu(
pack_a, cur_pack_b, K, output, LDC, _bias_ptr,
N - n > 4 ? 4 : N - n);
output += 4 * pack_mk;
;
cur_pack_b += K4;
}
pack_a += K4;
} -
汇编实现:内存重排
对 A 进行内存重排的程序入口如下,矩阵 A 的形状为 [M/4, K/4, 4, 4],如下图所示,A 矩阵在内存中首先以块顺序 0-7 存储,在每一块内按照 0-F 的顺序存储。对 A 进行内存重排后的形状为 [M/8, K, 8],在图中为了方便理解,箭头方向为数据存储的方向,假设 M=16,K=8。
以下程序描述了如何进行内存重排的过程。首先需要两个指针
inptr0
和inptr1
定位原矩阵的内存,其中inptr0
定位的是第 0 矩阵块(以下都称为块,描述的是上图中矩阵的分块情况)内存起点,inptr1
定位的是第 2 块内存起点,prefetch_2x
函数的作用是读取 32 个 fp32 数据,即 2 块,该函数将 0、1 和 2、3 块数据加载到缓存中,interleave_2x4_4_s
是重排函数,作用是将载入缓存的块数据重排,具体如何展开请看后续。注意,若 m 不被 8 整除,需要在下一个循环代码中处理尾部数据。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34void Arm64_fp32_m8_n12_mk4_matmul_NO_BIAS_packa_n(
float* outptr, const float* inptr, int ldin, int y0, int ymax, int k0,
int kmax) {
const int pack_mk = 4;
const int pack_m = 8;
const int m_stride = pack_m * pack_mk;
const int min_m_stride = pack_mk * pack_mk;
int y = 0;
for (; y + 7 < ymax; y += pack_m) {
const float* inptr0 = inptr + y / pack_mk * ldin;
const float* inptr1 = inptr0 + ldin;
prefetch_2x(inptr0);
prefetch_2x(inptr1);
int k = (kmax);
for (; k > 3; k -= pack_mk) {
interleave_2x4_4_s(inptr0, inptr1, outptr);
outptr += m_stride;
inptr0 += min_m_stride;
inptr1 += min_m_stride;
}
}
for (; y < ymax; y += pack_mk) {
const float* inptr0 = inptr + y / pack_mk * ldin;
prefetch_2x(inptr0);
int K = (kmax);
for (; K > 3; K -= pack_mk) {
interleave_1x4_4_s(inptr0, outptr);
outptr += min_m_stride;
inptr0 += min_m_stride;
}
}
}重排函数实现的重点是函数
interleave_2x4_4_s
,下图描述了矩阵 A 内存重排的实现过程。其中 v0-v7 为寄存器。程序首先将数据加载到寄存器后,重排顺序后再输出到矩阵 A 的输出位置指针。对 B 实现内存重排的原理类似,其原始形状为 [K/4, N, 4],重排后形状为 [N/12, K, 12],其形状变化和内存排序如下图所示,和 A 相比,需要多进行一个类似转置的操作。如下图所示,该重排把一个 4x12 的矩阵块(第一块绿色,紫色,黄色的三块)重排成了一个 12x4 的矩阵块。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52void Arm64_fp32_m8_n12_mk4_matmul_NO_BIAS_packb_n(
float* outptr, const float* inptr, int ldin, int x0, int xmax, int k0,
int kmax) {
float tmpbuff[16] = {0.0f};
int PACK_C_SIZE = 4;
int ksize = kmax - k0;
int ksize12 = ksize * 12;
int ksize4 = (ksize << 2);
float* outptr_base = outptr;
float* outptr_base4 = outptr_base + (xmax - x0) / 12 * ksize12;
int k = k0;
for (; k + 3 < kmax; k += 4) {
const float* temp_inptr = inptr + k / PACK_C_SIZE * ldin + x0 * PACK_C_SIZE;
prefetch_3x(temp_inptr);
int x = x0;
float* temp_outptr = outptr_base;
for (; x + 12 <= xmax; x += 12) {
float* outptr_interleave = temp_outptr;
transpose_1x12_4_s(temp_inptr, outptr_interleave);
temp_outptr += ksize12;
temp_inptr += 4 * 12;
}
temp_outptr = outptr_base4;
for (; x + 4 <= xmax; x += 4) {
float* outptr_interleave = temp_outptr;
asm volatile(
"ld4 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n"
"st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[outptr0]]\n"
: [inptr0] "+r"(temp_inptr), [outptr0] "+r"(outptr_interleave)
:
: "v0", "v1", "v2", "v3", "memory");
temp_outptr += ksize4;
}
if (x < xmax) {
memcpy(tmpbuff, temp_inptr, sizeof(float) * (xmax - x) * PACK_C_SIZE);
float* outptr_interleave = temp_outptr;
const float* tmp_ptr = &tmpbuff[0];
asm volatile(
"ld4 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n"
"st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[outptr0]]\n"
: [inptr0] "+r"(tmp_ptr), [outptr0] "+r"(outptr_interleave)
:
: "v0", "v1", "v2", "v3", "memory");
temp_outptr += ksize4;
}
outptr_base += 12 * 4;
outptr_base4 += 4 * 4;
}
} -
汇编实现: kernel func
根据实现细节,实现一个 M8N12K4 的算子需要以下几个分层运算函数:
kern_8x12_bias_relu
,kern_8x4_bias_relu
,kern_4x12_bias_relu
,kern_4x4_bias_relu
, 这些都是基于汇编语言实现的。以主要的kern_8x12_bias_relu
为例子,来看一下这个函数是如何实现分层计算的。
需要理解的汇编语言操作如下表- eor 异或操作
- ld1 从内存中加载数据
- prfm 预取内存到缓存
- cmp 比较操作符
- fmla 向量乘加操作
- bne 比较结果不相等时跳转
- st1 将寄存器的数据写入内存
根据优化思路的第 2 步,首先你需要加载 A 的 8 个元素(2 个寄存器)和 B 的 12 个元素(3 个寄存器)到 cache 中,再为中间结果矩阵准备 24 个寄存器。
在以下这段汇编代码中,首先通过异或操作eor
清空了 8-31 寄存器的所有字节(x.16b),一共 24 个,这些寄存器是为了存储中间结果准备的。用 output0 和 output1 两个指针作为输出,并通过prfm
把他们预取到缓存中。以下是汇编代码的主体实现,实际使用时,arm 架构有 32 个 128 位的寄存器,除去以上的 29 个,另外空闲的 3 个处理器分给 B 做数据的预取。-
Code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101"eor v8.16b, v8.16b, v8.16b \n" # 清空v8寄存器,下同
"eor v9.16b, v9.16b, v9.16b \n"
"eor v10.16b, v10.16b, v10.16b \n"
"prfm pstl1keep, [%[output0]] \n" # 预取
"eor v11.16b, v11.16b, v11.16b \n"
"eor v12.16b, v12.16b, v12.16b \n"
"eor v13.16b, v13.16b, v13.16b \n"
"prfm pstl1keep, [%[output1]] \n"
"eor v14.16b, v14.16b, v14.16b \n"
"eor v15.16b, v15.16b, v15.16b \n"
"ld1 {v2.4s}, [%[b_ptr]], #16 \n" # v2加载B的四个字节,同时将指针向后移16个字节
"eor v16.16b, v16.16b, v16.16b \n"
"eor v17.16b, v17.16b, v17.16b \n"
"eor v18.16b, v18.16b, v18.16b \n"
"eor v19.16b, v19.16b, v19.16b \n"
"eor v20.16b, v20.16b, v20.16b \n"
"ld1 {v3.4s}, [%[b_ptr]], #16 \n" # v3加载B的四个字节,同时将指针向后移16个字节
"eor v21.16b, v21.16b, v21.16b \n"
"eor v22.16b, v22.16b, v22.16b \n"
"eor v23.16b, v23.16b, v23.16b \n"
"ld1 {v4.4s}, [%[b_ptr]], #16 \n" # v4加载B的四个字节,同时将指针向后移16个字节
"eor v24.16b, v24.16b, v24.16b \n"
"eor v25.16b, v25.16b, v25.16b \n"
"eor v26.16b, v26.16b, v26.16b \n"
"eor v27.16b, v27.16b, v27.16b \n"
"eor v28.16b, v28.16b, v28.16b \n"
"ld1 {v0.4s}, [%[a_ptr]], #16 \n"# v0加载A的四个字节,同时将指针向后移16个字节
"eor v29.16b, v29.16b, v29.16b \n"
"eor v30.16b, v30.16b, v30.16b \n"
"eor v31.16b, v31.16b, v31.16b \n"
"2: \n"
"cmp %w[K], #0\n" #如果k=0 此时k=1
"beq 4f\n"
"3:\n"
"fmla v8.4s, v0.4s, v2.s[0]\n" # A的前四个元素和B的第一个元素相乘
"fmla v9.4s, v0.4s, v2.s[1]\n" # A的前四个元素和B的第二个元素相乘
"ld1 {v1.4s}, [%[a_ptr]], 16\n" # v1加载A的四个字节,同时指针偏移
"fmla v10.4s, v0.4s, v2.s[2]\n" # A的前四个元素和B的第三个元素相乘
"fmla v11.4s, v0.4s, v2.s[3]\n" # A的前四个元素和B的第四个元素相乘
"ld1 {v5.4s}, [%[b_ptr]], #16\n" # v5加载B的四个字节,同时指针偏移
"fmla v12.4s, v0.4s, v3.s[0]\n"
"fmla v13.4s, v0.4s, v3.s[1]\n"
"ld1 {v6.4s}, [%[b_ptr]], #16\n" # v6加载B的四个字节,同时指针偏移
"fmla v14.4s, v0.4s, v3.s[2]\n"
"fmla v15.4s, v0.4s, v3.s[3]\n"
"ld1 {v7.4s}, [%[b_ptr]], #16\n" # v7加载B的四个字节,同时指针偏移
"fmla v16.4s, v0.4s, v4.s[0]\n"
"fmla v17.4s, v0.4s, v4.s[1]\n"
"fmla v18.4s, v0.4s, v4.s[2]\n"
"fmla v19.4s, v0.4s, v4.s[3]\n"
"ld1 {v0.4s}, [%[a_ptr]], 16\n" # 此时A的前四个元素的计算完成,向后加载4个字节
"fmla v20.4s, v1.4s, v2.s[0]\n"
"fmla v21.4s, v1.4s, v2.s[1]\n"
"fmla v22.4s, v1.4s, v2.s[2]\n"
"fmla v23.4s, v1.4s, v2.s[3]\n"
"fmla v24.4s, v1.4s, v3.s[0]\n"
"fmla v25.4s, v1.4s, v3.s[1]\n"
"fmla v26.4s, v1.4s, v3.s[2]\n"
"fmla v27.4s, v1.4s, v3.s[3]\n"
"fmla v28.4s, v1.4s, v4.s[0]\n"
"fmla v29.4s, v1.4s, v4.s[1]\n"
"fmla v30.4s, v1.4s, v4.s[2]\n"
"fmla v31.4s, v1.4s, v4.s[3]\n"
"fmla v8.4s, v0.4s, v5.s[0]\n"
"fmla v9.4s, v0.4s, v5.s[1]\n"
"ld1 {v1.4s}, [%[a_ptr]], 16\n" # 此时A的第一个8x1的矩阵计算完毕,向后加载4个字节
"fmla v10.4s, v0.4s, v5.s[2]\n"
"fmla v11.4s, v0.4s, v5.s[3]\n"
"ld1 {v2.4s}, [%[b_ptr]], 16\n"
"fmla v12.4s, v0.4s, v6.s[0]\n"
"fmla v13.4s, v0.4s, v6.s[1]\n"
"ld1 {v3.4s}, [%[b_ptr]], 16\n"
"fmla v14.4s, v0.4s, v6.s[2]\n"
"fmla v15.4s, v0.4s, v6.s[3]\n"
"ld1 {v4.4s}, [%[b_ptr]], 16\n" # B也向后加载一个1x12的矩阵分块
"fmla v16.4s, v0.4s, v7.s[0]\n"
"fmla v17.4s, v0.4s, v7.s[1]\n"
"fmla v18.4s, v0.4s, v7.s[2]\n"
"fmla v19.4s, v0.4s, v7.s[3]\n"
"ld1 {v0.4s}, [%[a_ptr]], 16\n" # 同上
"fmla v20.4s, v1.4s, v5.s[0]\n"
"fmla v21.4s, v1.4s, v5.s[1]\n"
"fmla v22.4s, v1.4s, v5.s[2]\n"
"fmla v23.4s, v1.4s, v5.s[3]\n"
"fmla v24.4s, v1.4s, v6.s[0]\n"
"subs %w[K], %w[K], #1\n" #处理K的大小
"fmla v25.4s, v1.4s, v6.s[1]\n"
"fmla v26.4s, v1.4s, v6.s[2]\n"
"fmla v27.4s, v1.4s, v6.s[3]\n"
"fmla v28.4s, v1.4s, v7.s[0]\n"
"fmla v29.4s, v1.4s, v7.s[1]\n"
"fmla v30.4s, v1.4s, v7.s[2]\n"
"fmla v31.4s, v1.4s, v7.s[3]\n"
"bne 3b\n"
# 后面是一些处理尾部的代码块,此处略