* limitations under the License.
*******************************************************************************/
-#include <math.h>
+#include <cmath>
+#include <mutex>
#include "mkldnn_thread.hpp"
#include "utils.hpp"
-#include "gemm_utils.hpp"
+#include "ref_gemm_f32.hpp"
+#include "gemm_utils_f32.hpp"
#include "jit_avx512_common_gemm_f32.hpp"
-#define CACHE_LINE_SIZE 64
+#include "jit_generator.hpp"
namespace mkldnn {
namespace impl {
namespace cpu {
-using namespace mkldnn::impl::memory_format;
-using namespace mkldnn::impl::utils;
+#define CACHE_LINE_SIZE 64
-using namespace Xbyak;
#define STACKSIZE get_size_of_abi_save_regs()
#ifdef _WIN32
#define STACK_K_CAPACITY 32
#define UNROLL_M 48
#define UNROLL_N 8
-struct jit_avx512_common_gemm_f32::xbyak_gemm : public jit_generator {
- xbyak_gemm(char transa, char transb, float beta, bool hasBias = false,
+namespace avx512_common_gemm_f32 {
+using namespace gemm_utils;
+
+// JIT code generator for the no-copy AVX512 SGEMM kernel. Each instance is
+// specialized at generation time on (isTransA, isTransB, beta, hasBias);
+// note the bool parameters replace the old char transa/transb arguments.
+struct xbyak_gemm : public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_gemm_f32_xbyak_gemm)
+
+ xbyak_gemm(char isTransA, char isTransB, float beta, bool hasBias = false,
 void *code_ptr = nullptr,
 size_t code_size = 80 * Xbyak::DEFAULT_MAX_CODE_SIZE)
 : jit_generator(code_ptr, code_size)
 {
+ using namespace Xbyak;
+
 enum { ver_avx512_core, ver_avx512_mic } ver =
 mayiuse(avx512_core) ? ver_avx512_core : ver_avx512_mic;
- bool isTransA = (transa == 'T' || transa == 't');
- bool isTransB = (transb == 'T' || transb == 't');
 bool isBeta0 = (beta == 0.0);
 bool isBetaN = (!isBeta0 && beta != 1.0);
 vzeroupper();
 postamble();
- ker_ = reinterpret_cast<decltype(ker_)>(
- const_cast<uint8_t *>(this->getCode()));
+ // getCode<ker_t>() replaces the old reinterpret_cast/const_cast chain.
+ ker_ = this->getCode<ker_t>();
 }
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_gemm_f32_xbyak_gemm)
+ // Signature of the generated kernel entry point; dim_t replaces the old
+ // long long int spellings throughout.
+ typedef void (*ker_t)(dim_t m, dim_t n, dim_t k,
+ const float *alpha, const float *a, dim_t lda,
+ const float *b, dim_t ldb, const float *beta, float *c,
+ dim_t ldc, const float *bias, float *ws);
- void operator()(long long int m, long long int n, long long int k,
- const float *alpha, const float *a, long long int lda,
- const float *b, long long int ldb, const float *beta, float *c,
- long long int ldc, const float *bias, float *ws)
+ // Invokes the generated kernel; const because it only reads ker_.
+ void operator()(dim_t m, dim_t n, dim_t k,
+ const float *alpha, const float *a, dim_t lda,
+ const float *b, dim_t ldb, const float *beta, float *c,
+ dim_t ldc, const float *bias, float *ws) const
 {
- (*ker_)(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, bias, ws);
+ ker_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, bias, ws);
 }
 private:
- void (*ker_)(long long int m, long long int n, long long int k,
- const float *alpha, const float *a, long long int lda,
- const float *b, long long int ldb, const float *beta, float *c,
- long long int ldc, const float *bias, float *ws);
+ ker_t ker_;
};
-typedef void (*ker)(long long int, long long int, long long int, float *,
- float *, long long int, float *, long long int, float *, float *,
- long long int, float *, float *);
-void jit_avx512_common_gemm_f32::sgemm_nocopy_driver(const char *transa,
+// Returns the JIT kernel matching the requested (transA, transB, beta, bias)
+// combination. The first call eagerly generates every supported variant under
+// std::call_once, so lookups are thread-safe and lock-free afterwards.
+// Unsupported slots (bias with beta != 0) are left as nullptr and returned
+// as such. The kernels are heap-allocated once and never freed: they live
+// for the lifetime of the process.
+const xbyak_gemm *get_xbyak_gemm(
+ bool isTransA, bool isTransB, float beta, bool hasBias) {
+ // Maps beta to the table's third index: 0 -> 0, 1 -> 1, anything else -> 2.
+ auto beta_idx = [](float beta) {
+ return (beta == 0.0) ? 0 : (beta == 1.0 ? 1 : 2);
+ };
+
+ // Kernel table [isTransA][isTransB][hasBias][beta (0, 1, other)]
+ static xbyak_gemm *kernel_table[2][2][2][3];
+ static std::once_flag initialized;
+ std::call_once(initialized, [=]{
+ for (bool isTransA: {false, true})
+ for (bool isTransB: {false, true})
+ for (bool hasBias: {false, true})
+ // 2.0f stands in for the generic "beta is neither 0 nor 1" case.
+ for (float beta: {0.0f, 1.0f, 2.0f}) {
+ // nocopy sgemm with bias for beta != 0.0 is not supported
+ if (hasBias && beta != 0.0)
+ continue;
+ kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)] =
+ new xbyak_gemm(isTransA, isTransB, beta, hasBias);
+ }
+ });
+
+ return kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)];
+}
+
+// Blocked single-thread driver for the no-copy SGEMM kernels. Walks the
+// problem in (BM, BN, BK) tiles and dispatches each tile to the beta==0,
+// beta==1, or generic-beta kernel as appropriate. Leading dimensions are
+// dim_t now (they were int), which is why the (size_t) casts on the
+// pointer arithmetic below could be dropped: int * dim_t promotes to dim_t.
+void sgemm_nocopy_driver(const char *transa,
 const char *transb, int m, int n, int k, const float *alpha,
- const float *a, int lda, const float *b, int ldb, const float *beta,
- float *c, int ldc, const float *bias, float *ws)
+ const float *a, dim_t lda, const float *b, dim_t ldb, const float *beta,
+ float *c, dim_t ldc, const float *bias, float *ws)
{
 bool isTransA = (*transa == 'T' || *transa == 't');
 bool isTransB = (*transb == 'T' || *transb == 't');
 return;
}
+ // bias is only supported with beta == 0 (see get_xbyak_gemm's table).
+ assert(IMPLICATION(bias != nullptr, *beta == 0.0));
+
+ // XXX: this happens on every thread...
+ bool hasBias = (bias != nullptr);
+ auto ker_bn = get_xbyak_gemm(isTransA, isTransB, *beta, hasBias);
+ auto ker_b1 = get_xbyak_gemm(isTransA, isTransB, 1.0, false);
+ auto ker_b0 = get_xbyak_gemm(isTransA, isTransB, 0.0, false);
+ assert(ker_bn && ker_b1 && ker_b0);
+
 int BM = 4032, BN, BK;
 if (mayiuse(avx512_core)) {
 BN = isTransA ? 384 : 64;
 }
 if (!isTransA) {
- curA = a + Bm + (size_t)Bk * lda;
+ curA = a + Bm + Bk * lda;
 } else {
- curA = a + Bk + (size_t)Bm * lda;
+ curA = a + Bk + Bm * lda;
 }
 if (!isTransB) {
- curB = b + Bk + (size_t)Bn * ldb;
+ curB = b + Bk + Bn * ldb;
 } else {
- curB = b + Bn + (size_t)Bk * ldb;
+ curB = b + Bn + Bk * ldb;
 }
 curC = c + Bm + (size_t)Bn * ldc;
 if (bias != nullptr) {
 }
 if (Bk == 0) {
 if (*beta == 0.0 && bias == nullptr)
- (*ker_b0_)((long long int)sizeM, (long long int)sizeN,
- (long long int)sizeK, alpha, curA,
- (long long int)lda, curB, (long long int)ldb,
- beta, curC, (long long int)ldc, curBias, ws);
+ (*ker_b0)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK,
+ alpha, curA, lda, curB, ldb, beta, curC, ldc,
+ curBias, ws);
 else
- (*ker_bn_)((long long int)sizeM, (long long int)sizeN,
- (long long int)sizeK, alpha, curA,
- (long long int)lda, curB, (long long int)ldb,
- beta, curC, (long long int)ldc, curBias, ws);
+ (*ker_bn)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK,
+ alpha, curA, lda, curB, ldb, beta, curC, ldc,
+ curBias, ws);
 } else {
- (*ker_b1_)((long long int)sizeM, (long long int)sizeN,
- (long long int)sizeK, alpha, curA,
- (long long int)lda, curB, (long long int)ldb, beta,
- curC, (long long int)ldc, curBias, ws);
+ // Bk > 0: C already holds a partial product, so accumulate (beta=1).
+ (*ker_b1)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK,
+ alpha, curA, lda, curB, ldb, beta, curC, ldc,
+ curBias, ws);
 }
 }
 }
}
- return;
}
-void jit_avx512_common_gemm_f32::sgemm(const char *transa, const char *transb,
+}
+
+// Parallel entry point: C = alpha * op(A) * op(B) + beta * C (+ bias).
+// Replaces the old stateful class with a free function; the per-(transa,
+// transb, beta) kernel state now lives in get_xbyak_gemm's static table.
+mkldnn_status_t jit_avx512_common_gemm_f32(
+ const char *transa, const char *transb,
 const int *p_m, const int *p_n, const int *p_k, const float *p_alpha,
 const float *A, const int *p_lda, const float *B, const int *p_ldb,
 const float *p_beta, float *C, const int *p_ldc, const float *bias)
{
- if (beta_ == 0. || beta_ == 1.)
- assert(*p_beta == beta_);
- assert((one_of(*transa, 'T', 't') == one_of(transa_, 'T', 't')));
+ using namespace mkldnn::impl::utils;
+ using namespace avx512_common_gemm_f32;
+ using namespace gemm_utils;
+
+ // bias with beta != 0 is unsupported by the JIT path; fall back to the
+ // reference implementation.
+ // NOTE(review): B's leading dimension below is passed as p_lda, not
+ // p_ldb -- this looks like a copy/paste bug in the fallback call; verify
+ // against ref_gemm's signature and fix in a follow-up hunk.
+ if (*p_beta != 0 && bias)
+ return ref_gemm(transa, transb, p_m, p_n, p_k,
+ p_alpha, A, p_lda, B, p_lda, p_beta, C, p_ldc, bias);
 int nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads();
+
 int m = *p_m;
 int n = *p_n;
 int k = *p_k;
- int lda = *p_lda;
- int ldb = *p_ldb;
- int ldc = *p_ldc;
+ dim_t lda = *p_lda;
+ dim_t ldb = *p_ldb;
+ dim_t ldc = *p_ldc;
 float beta = *p_beta;
 int MB, NB, KB;
 int nthr_m, nthr_n, nthr_k, nthr_mn;
- assert(nthr <= nthrs_);
-
 // Determine threading partitioning
- gemm_utils::calc_nthr_nocopy_avx512_common(
+ calc_nthr_nocopy_avx512_common(
 m, n, k, nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB);
 assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1));
 CACHE_LINE_SIZE);
 ompstatus = (unsigned char volatile *) ompstatus_;
 assert(ompstatus);
+
 for (int i = 0; i < nthr; i++)
 ompstatus[i * CACHE_LINE_SIZE] = 0;
 * sizeof(float), PAGE_4K);
 }
- const size_t ws_elems_per_thr = k * 48 + 64;
+ // (size_t) cast guards the k * 48 product against int overflow.
+ const size_t ws_elems_per_thr = (size_t)k * 48 + 64;
 const size_t ws_size_per_thr
- = utils::rnd_up(ws_elems_per_thr * sizeof(float), PAGE_4K);
+ = rnd_up(ws_elems_per_thr * sizeof(float), PAGE_4K);
 if (k > STACK_K_CAPACITY) {
 ws_buffers = (float *)malloc(nthr * ws_size_per_thr, PAGE_4K);
 }
- parallel(nthr, [&](const int ithr, const int nthr) {
+ parallel_nd(nthr, [&](const int ithr) {
 int ithr_m, ithr_n, ithr_k, ithr_mn;
 int m_from, m_to, myM;
 int n_from, n_to, myN;
 float *myC = C, myBeta;
 float *ws = ws_buffers ?
 ws_buffers + ithr * ws_size_per_thr / sizeof(float) : 0;
- int ld = ldc;
+ dim_t ld = ldc;
+
+ // Defer the K-partition reduction to a second pass when fewer threads
+ // are actually running than partitions exist -- presumably because the
+ // in-loop spin-waits below could otherwise wait on work that never
+ // starts; verify against mkldnn_thr_syncable() semantics.
+ int sum_later = (mkldnn_get_num_threads() < nthr_m * nthr_n * nthr_k);
 if (ithr < nthr_m * nthr_n * nthr_k) {
 myC = &(C[m_from + n_from * ldc]);
 myBeta = beta;
 ld = ldc;
- if (hasBias_)
+ if (bias)
 myBias = &(bias[m_from]);
 } else {
- myC = c_buffers + MB * NB * (cbase + ithr_k - 1);
+ // (dim_t) cast keeps the MB * NB * ... offset in 64-bit arithmetic.
+ myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1);
 myBeta = 0.0;
 ld = MB;
 myBias = nullptr;
 sgemm_nocopy_driver(transa, transb, myM, myN, myK, p_alpha, myA,
 lda, myB, ldb, &myBeta, myC, ld, myBias, ws);
- if (nthr_k > 1)
+ if (nthr_k > 1 && !sum_later)
 ompstatus[(ibase + ithr_k) * CACHE_LINE_SIZE] = 1;
 }
- if (nthr_k > 1) {
+ if (nthr_k > 1 && !sum_later) {
 // sum matrices partitioned along K dimension
 int n1, n2;
- gemm_utils::partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2);
+ partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2);
 if (ithr_k > 0) {
- myC = c_buffers + MB * NB * (cbase + ithr_k - 1);
- myC = myC + n1 * MB;
+ myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1)
+ + (dim_t)n1 * MB;
 /* need to wait until main thread finishes */
 while (ompstatus[ibase * CACHE_LINE_SIZE] != 1) {
 };
 /* my cache is hot */
- gemm_utils::sum_two_matrices(myM, n2, myC, MB,
+ sum_two_matrices(myM, n2, myC, MB,
 &C[m_from + (n_from + n1) * ldc], ldc);
 }
 for (int ik = 1; ik < nthr_k; ++ik) {
 if (ik != ithr_k) {
- myC = c_buffers + MB * NB * (cbase + ik - 1);
- myC = myC + n1 * MB;
+ myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1)
+ + (dim_t)n1 * MB;
 while (ompstatus[(ibase + ik) * CACHE_LINE_SIZE] != 1) {
 };
- gemm_utils::sum_two_matrices(myM, n2, myC, MB,
+ sum_two_matrices(myM, n2, myC, MB,
 &C[m_from + (n_from + n1) * ldc], ldc);
 }
 }
 }
 });
+
+ // handle C summation later
+ // (ompstatus[0] == 0 means the first pass skipped the reduction because
+ // sum_later was set; redo the partitioning and reduce serially-synced.)
+ if (nthr_k > 1 && ompstatus[0] == 0) {
+
+ parallel_nd(nthr, [&](const int ithr) {
+ int ithr_m, ithr_n, ithr_k, ithr_mn;
+ int m_from, m_to, myM;
+ int n_from, n_to, myN;
+ int cbase;
+ float *myC = C;
+
+ if (ithr < nthr_m * nthr_n * nthr_k) {
+
+ ithr_mn = ithr % nthr_mn;
+ ithr_m = ithr_mn % nthr_m;
+ ithr_n = ithr_mn / nthr_m;
+ ithr_k = ithr / nthr_mn;
+
+ /* swap ithr_k for performance improvement */
+ if (ithr_k == 0)
+ ithr_k = nthr_k - 1;
+ else if (ithr_k == nthr_k - 1)
+ ithr_k = 0;
+
+ m_from = MB * (ithr_m);
+ m_to = MB * (ithr_m + 1);
+ if (m_to > m)
+ m_to = m;
+ myM = m_to - m_from;
+
+ n_from = NB * (ithr_n);
+ n_to = NB * (ithr_n + 1);
+ if (n_to > n)
+ n_to = n;
+ myN = n_to - n_from;
+
+ cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
+
+ if (nthr_k > 1) {
+ // sum matrices partitioned along K dimension
+ int n1, n2;
+
+ partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2);
+
+ if (ithr_k > 0) {
+
+ myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1)
+ + (dim_t)n1 * MB;
+
+ /* my cache is hot */
+ sum_two_matrices(myM, n2, myC, MB,
+ &C[m_from + (n_from + n1) * ldc], ldc);
+ }
+
+ for (int ik = 1; ik < nthr_k; ++ik) {
+ if (ik != ithr_k) {
+
+ myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1)
+ + (dim_t)n1 * MB;
+
+ sum_two_matrices(myM, n2, myC, MB,
+ &C[m_from + (n_from + n1) * ldc], ldc);
+ }
+ }
+ }
+ }
+ });
+ }
+
 free(c_buffers);
 free(ompstatus_);
 free(ws_buffers);
-}
-jit_avx512_common_gemm_f32::jit_avx512_common_gemm_f32(
- char transa, char transb, float beta, bool hasBias)
-{
- transa_ = transa;
- transb_ = transb;
- beta_ = beta;
- hasBias_ = hasBias;
- if (hasBias) {
- assert(beta == 0.0);
- }
- ker_bn_ = new xbyak_gemm(transa, transb, beta, hasBias);
- if (beta != 1.0) {
- ker_b1_ = new xbyak_gemm(transa, transb, 1.0);
- } else {
- ker_b1_ = ker_bn_;
- }
- if (beta != 0.0 || (beta == 0.0 && hasBias)) {
- ker_b0_ = new xbyak_gemm(transa, transb, 0.0);
- } else {
- ker_b0_ = ker_bn_;
- }
-
- nthrs_ = mkldnn_get_max_threads();
+ return mkldnn_success;
}
-jit_avx512_common_gemm_f32::~jit_avx512_common_gemm_f32()
-{
- delete ker_bn_;
- if (beta_ != 1.0)
- delete ker_b1_;
- if (beta_ != 0.0 || (beta_ == 0.0 && hasBias_))
- delete ker_b0_;
-}
}
}
}