目次: ベンチマーク
お手軽最適化のメモ、昨日の続きです。行列の掛け算を題材にします。前回は行列の掛け算と素朴な実装のコードを紹介しました。今回はお手軽最適化を紹介します。
スカラー処理だと遅いけれど、お手軽に最適化(数倍程度)がしたいときの参考になれば幸いです。
素朴版のコードですとi, j, kの順でループになっていて、kを最内ループにしていました。ループ内の計算は、
c[i * nn + j] += a[i * kk + k] * b[k * nn + j]
でした。このときメモリアクセスのパターンは、
です。GCCの自動ベクトル化ですと、このアクセスパターンをうまく最適化できないようです。対象が最内ループのみなのかもしれません。ループを入れ替えjを最内にして、BとCを行方向に読むようにします。するとメモリアクセスのパターンは、
です。ループを入れ替えるとループ内でCの0初期化ができないので、ループの外に追い出して最終的に下記のようなコードになります。
void sgemm_inner(const float *a, const float *b, float *c, int mm, int nn, int kk)
{
for (int i = 0; i < mm; i++) {
for (int j = 0; j < nn; j++) {
c[i * nn + j] = 0.0f;
}
}
for (int i = 0; i < mm; i++) {
for (int k = 0; k < kk; k++) {
for (int j = 0; j < nn; j++) { //★★jが最内ループ★★
c[i * nn + j] += a[i * kk + k] * b[k * nn + j];
}
}
}
}
$ gcc -Wall -g -O2 -fno-tree-vectorize -static -march=znver3 sgemm.c $ ./a.out matrix size: M:1519, N:1517, K:1523 time: 1.528314 (参考: 素朴版の実行時間) time: 2.277758 (参考: OpenBLASシングルスレッドの実行時間) $ export OPENBLAS_NUM_THREADS=1 ----- use CBLAS verify: 0.052149
ループ入れ替えで倍くらい速くなっていますが、この最適化の本領はコンパイラの自動ベクトル化です。GCCならば -ftree-vectorizeオプションを指定すると、行列Bと行列CへのアクセスにSIMD命令を使うようになります。
Ryzen 7 5700Xの場合はAVX2命令を使えます。他のCPUをお使いの場合は -marchを適宜変更してください。
$ gcc -Wall -g -O2 -ftree-vectorize -static -march=znver3 sgemm.c $ ./a.out matrix size: M:1519, N:1517, K:1523 time: 0.181133 (参考: OpenBLASシングルスレッドの実行時間) $ export OPENBLAS_NUM_THREADS=1 ----- use CBLAS verify: 0.052149
素朴版とOpenBLASでは1/43もの差がありましたが、ループ入れ替えと自動ベクトル化によってOpenBLASの1/3.5程度まで近づきました。GEMMが計算偏重の処理で最適化の効果が出やすい、という点を考慮する必要はあるものの僅かな書き換えで得られる効果にしては割と良いのではないでしょうか。
ソースコードを書き換える元気があれば、別の最適化方法もあります。GEMM特有の話に近付いてしまい、汎用的な話から遠ざかりますが、最適化ポイントの例という意味では参考になるはず……です。たぶん。
今回はSIMD命令でjの方向に一気に読むまでは同じですが、jの方向に進めるのではなく、kの方向に進めて、計算結果をCに足していく戦略です。具体的に言えばAi,kとBk,j〜Bk,j+7の8要素を一気に掛け算してCi,j〜Ci,j+7へ一気に足します。なぜ8要素かというとAVX/AVX2のレジスタ長(256bit)を使うと、32bit長のfloatを一度に8要素処理できるためです。
SIMD命令にはAi,kをSIMDレジスタの全要素に配る命令(set1_psから生成されるbroadcast命令)や、掛け算と足し算を一度に行うfmadd命令など、この計算順に最適な命令が揃っています。
AVX/AVX2のIntrinsicsの詳細についてはIntelのサイトなどを見ていただくとして(Intel Intrinsics Guide)、コードは下記のようになります。
#include <immintrin.h>
void sgemm_avx(const float *a, const float *b, float *c, int mm, int nn, int kk)
{
for (int j = 0; j < nn;) {
if (nn - j >= 8) {
for (int i = 0; i < mm; i++) {
__m256 vc = _mm256_set1_ps(0.0f);
for (int k = 0; k < kk; k++) {
__m256 va = _mm256_set1_ps(a[i * kk + k]);
__m256 vb = _mm256_loadu_ps(&b[k * nn + j]);
vc = _mm256_fmadd_ps(va, vb, vc);
}
_mm256_storeu_ps(&c[i * nn + j], vc);
}
j += 8;
} else {
for (int i = 0; i < mm; i++) {
c[i * nn + j] = 0.0f;
for (int k = 0; k < kk; k++) {
c[i * nn + j] += a[i * kk + k] * b[k * nn + j];
}
}
j++;
}
}
}
$ gcc -Wall -g -O2 -static -march=znver3 sgemm.c $ ./a.out matrix size: M:1519, N:1517, K:1523 time: 0.384861 (参考: 素朴版の実行時間) time: 2.277758 (参考: ループ入れ替え版+自動ベクトル化の実行時間) time: 0.181133 (参考: OpenBLASシングルスレッドの実行時間) $ export OPENBLAS_NUM_THREADS=1 ----- use CBLAS verify: 0.052149
素朴版と比べると6倍速いですが、ループ入れ替え+自動ベクトル化には負けています。
先程のコードはSIMDレジスタを3個しか同時に使っていませんでした。AVX/AVX2のYMMレジスタは16個もあるのに3個しか使わないのはもったいですから、iのループを8要素ずつアンローリングしてSIMDレジスタを同時にたくさん使いましょう。レジスタをうまく使いまわせば12要素のアンローリング(Bの保持に1個、Cの保持に12個、Aの保持に1個、計14個)まではできそうです。たぶん。
#include <immintrin.h>
void sgemm_avx_unroll8(const float *a, const float *b, float *c, int mm, int nn, int kk)
{
for (int j = 0; j < nn;) {
if (nn - j >= 8) {
int i = 0;
for (; i < (mm & ~7); i += 8) {
__m256 vc0 = _mm256_set1_ps(0.0f);
__m256 vc1 = _mm256_set1_ps(0.0f);
__m256 vc2 = _mm256_set1_ps(0.0f);
__m256 vc3 = _mm256_set1_ps(0.0f);
__m256 vc4 = _mm256_set1_ps(0.0f);
__m256 vc5 = _mm256_set1_ps(0.0f);
__m256 vc6 = _mm256_set1_ps(0.0f);
__m256 vc7 = _mm256_set1_ps(0.0f);
for (int k = 0; k < kk; k++) {
__m256 vb = _mm256_loadu_ps(&b[k * nn + j]);
__m256 va0 = _mm256_set1_ps(a[(i+0) * kk + k]);
__m256 va1 = _mm256_set1_ps(a[(i+1) * kk + k]);
__m256 va2 = _mm256_set1_ps(a[(i+2) * kk + k]);
__m256 va3 = _mm256_set1_ps(a[(i+3) * kk + k]);
__m256 va4 = _mm256_set1_ps(a[(i+4) * kk + k]);
__m256 va5 = _mm256_set1_ps(a[(i+5) * kk + k]);
__m256 va6 = _mm256_set1_ps(a[(i+6) * kk + k]);
__m256 va7 = _mm256_set1_ps(a[(i+7) * kk + k]);
vc0 = _mm256_fmadd_ps(va0, vb, vc0);
vc1 = _mm256_fmadd_ps(va1, vb, vc1);
vc2 = _mm256_fmadd_ps(va2, vb, vc2);
vc3 = _mm256_fmadd_ps(va3, vb, vc3);
vc4 = _mm256_fmadd_ps(va4, vb, vc4);
vc5 = _mm256_fmadd_ps(va5, vb, vc5);
vc6 = _mm256_fmadd_ps(va6, vb, vc6);
vc7 = _mm256_fmadd_ps(va7, vb, vc7);
}
_mm256_storeu_ps(&c[(i+0) * nn + j], vc0);
_mm256_storeu_ps(&c[(i+1) * nn + j], vc1);
_mm256_storeu_ps(&c[(i+2) * nn + j], vc2);
_mm256_storeu_ps(&c[(i+3) * nn + j], vc3);
_mm256_storeu_ps(&c[(i+4) * nn + j], vc4);
_mm256_storeu_ps(&c[(i+5) * nn + j], vc5);
_mm256_storeu_ps(&c[(i+6) * nn + j], vc6);
_mm256_storeu_ps(&c[(i+7) * nn + j], vc7);
}
for (; i < mm; i++) {
__m256 vc = _mm256_set1_ps(0.0f);
for (int k = 0; k < kk; k++) {
__m256 vb = _mm256_loadu_ps(&b[k * nn + j]);
__m256 va = _mm256_broadcast_ss(&a[(i+0) * kk + k]);
vc = _mm256_fmadd_ps(va, vb, vc);
}
_mm256_storeu_ps(&c[i * nn + j], vc);
}
j += 8;
} else {
for (int i = 0; i < mm; i++) {
c[i * nn + j] = 0.0f;
for (int k = 0; k < kk; k++) {
c[i * nn + j] += a[i * kk + k] * b[k * nn + j];
}
}
j++;
}
}
}
$ ./a.out matrix size: M:1519, N:1517, K:1523 time: 0.108420 (参考: ループ入れ替え版+自動ベクトル化の実行時間) time: 0.181133 (参考: OpenBLASシングルスレッドの実行時間) $ export OPENBLAS_NUM_THREADS=1 ----- use CBLAS verify: 0.052149
もはや最適化前のコードの原型がありませんが、ループ入れ替え版+自動ベクトル化の1.7倍くらいの速度になりました。OpenBLASの1/2程度まで迫っています。この最適化手法が汎用的か?と聞かれると何とも言えないですが、SIMDレジスタを同時にたくさん使う、最内ループ以外もアンローリング(最内ループはコンパイラがやってくれる)辺りは割と汎用的なアイデアです。
あと前回言った通り、GEMMは最適化の題材として取り上げただけなので、実際にGEMMを計算する場合はこのコードや自分で書いたコードを使うのではなく、信頼と実績のOpenBLASを使ってくださいませ。
目次: ベンチマーク
お手軽最適化のメモです。行列の掛け算を題材にします。
最初にお断りしておくとGEMMのような汎用処理の場合は、自分で最適化せずにOpenBLASを使ってください(素人が最適化しても勝てません)。しかしOpenBLASのような限界まで最適化されたコードは誰でも簡単には書ける、とは言えません。
スカラー処理だと遅いけれど、お手軽に最適化(数倍程度)がしたいときの参考になれば幸いです。
GEMMはGeneral Matrix Multiplyの略で、高校数学辺りでやった(はず)の行列の掛け算のことです。floatの場合はSGEMMと呼ばれ、doubleの場合はDGEMMと呼ばれます。最適化の題材はどちらでも良いんですけど、今回はSGEMMを使います。
忘れている方のために2行3列の行列Aと3行2列Bの掛け算A x B = Cだとこんな感じです。
Aの列数とBの行数は一致していなければなりません。行数と列数の関係を表すと、行列A(M行K列)x行列B(K行N列)= 行列C(M行N列)となります。
Cの1要素を計算するには、Aの1行とBの1列が必要です。式、および、視覚的に示すと下記のようになります。
説明はこれくらいにしてコードを見ましょう。
SGEMMを素直にコードにするとこんな感じです。行方向にデータを格納(Row-major orderといいます)しているので、N列の行列Cのi行j列(以降Ci,jと書く)にアクセスする際はc[i * N + j] とします。
void sgemm_naive(const float *a, const float *b, float *c, int mm, int nn, int kk)
{
for (int i = 0; i < mm; i++) {
for (int j = 0; j < nn; j++) {
c[i * nn + j] = 0.0f;
for (int k = 0; k < kk; k++) {
c[i * nn + j] += a[i * kk + k] * b[k * nn + j];
}
}
}
}
行列のサイズを適当に設定(M = 1519, N = 1517, K = 1523)して、実行時間を測ります。CPUはRyzen 7 5700Xです。
$ gcc -Wall -g -O2 -static sgemm.c $ ./a.out matrix size: M:1519, N:1517, K:1523 time: 2.277758
実行時間は実行するたびに変わりますが、大体2.27秒くらいでしょうか。OpenBLASのシングルスレッド(環境変数OPENBLAS_NUM_THREADS=1にするとシングルスレッド動作になります)で計算した時間を見ると、
c_ex = malloc(m * n * sizeof(float) * 2);
// C = alpha AB + beta C
float alpha = 1.0f, beta = 0.0f;
int lda = k, ldb = n, ldc = n;
printf("----- use CBLAS\n");
gettimeofday(&st, NULL);
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
m, n, k, alpha, a, lda, b, ldb, beta, c_ex, ldc);
gettimeofday(&ed, NULL);
timersub(&ed, &st, &ela);
printf("verify: %d.%06d\n", (int)ela.tv_sec, (int)ela.tv_usec);
$ export OPENBLAS_NUM_THREADS=1 $ gcc -Wall -g -O2 -static -L path/to/openblas/ sgemm.c -lopenblas $ ./a.out matrix size: M:1519, N:1517, K:1523 ----- use CBLAS verify: 0.052149
わずか0.05秒、実に43倍という驚異のスピードです。すごいですね……。
行列の掛け算の説明でほぼ終わってしまいました。お手軽最適化はまた次に。
東京の道はようわからん通称(明治通り、みたいなやつ)が付いています。東京育ちの人にはなじみ深い名前だと思うんですが、都外出身者からすると、東京の道の通称と国道何号線、都道何号線の区分が全く合っていないので、知らない場所の道の通称を言われると結構困ります。
なぜかというと国道何号線、都道何号線という表記は地図で省かれることはない一方で、東京の道の通称は高速道路や地下鉄と重なったときに省かれてしまうことがあるためです。地図を見ても「〇〇通り?何それ?どこ??」と困惑します。
東京の道の通称は東京都が決めています(東京都通称道路名〜道路のわかりやすく親しみやすい名称〜 - 東京都建設局)。道案内の青看板に載るような名前と思ってもらえばわかりやすいでしょう。
東京都がまとめてくれているのはありがたいんですが「東京都が公式に定めた通称」というのは不思議な響きで、それは公称ではなかろうか?通称とは一体……??
東京の道の通称には「年号」+「通り」という名前がいくつかあります。このうちなぜか大正通りだけは存在しません。不思議ですね。
調べてみると割と有名な話らしいです(「明治通り」「昭和通り」はあるのに、なぜか「大正通り」はない東京のちょっとした謎 - アーバン ライフ メトロ)。詳しくは記事を読んでいただくとして、簡単に言えば、
ということみたいです。
< | 2023 | > | ||||
<< | < | 02 | > | >> | ||
日 | 月 | 火 | 水 | 木 | 金 | 土 |
- | - | - | 1 | 2 | 3 | 4 |
5 | 6 | 7 | 8 | 9 | 10 | 11 |
12 | 13 | 14 | 15 | 16 | 17 | 18 |
19 | 20 | 21 | 22 | 23 | 24 | 25 |
26 | 27 | 28 | - | - | - | - |
合計:
本日: