namespace mkldnn {
-struct test_params {
+struct test_igemm_params {
char offsetc;
+ bool zero_oa;
+ bool zero_ob;
+ bool zero_oc;
+};
+
+struct test_params {
char transA;
char transB;
int M;
int ldb;
int ldc;
+ test_igemm_params igemm_params;
bool expect_to_fail;
mkldnn_status_t expected_status;
};
template <typename b_dt>
void ref_gemm_s8x8s32(const char *transa, const char *transb,
const char *offsetc, int m, int n, int k, const float alpha,
- int8_t *A, int lda, const int8_t *ao, b_dt *B, int ldb,
- const int8_t *bo, const float beta, int32_t *C, int ldc,
- const int32_t *co) {
+ int8_t *A, int lda, const int8_t *oa, b_dt *B, int ldb,
+ const int8_t *ob, const float beta, int32_t *C, int ldc,
+ const int32_t *oc) {
bool OCisR = (*offsetc == 'R' || *offsetc == 'r');
bool OCisC = (*offsetc == 'C' || *offsetc == 'c');
const int a_cols = AisN ? k : m;
mkldnn::impl::parallel_nd(a_cols, a_rows, [&](int j, int i) {
da_setter(i, j,
- static_cast<double>(ia_accessor(i, j)) + static_cast<double>(ao[0]));
+ static_cast<double>(ia_accessor(i, j)) + static_cast<double>(oa[0]));
});
const int b_rows = BisN ? k : n;
const int b_cols = BisN ? n : k;
mkldnn::impl::parallel_nd(b_cols, b_rows, [&](int j, int i) {
db_setter(i, j,
- static_cast<double>(ib_accessor(i, j)) + static_cast<double>(bo[0]));
+ static_cast<double>(ib_accessor(i, j)) + static_cast<double>(ob[0]));
});
ref_gemm(transa, transb, m, n, k, 1.0, dA, lda, dB, ldb, 0.0, dC, ldc);
auto f2d = [=] (float v) { return static_cast<double>(v); };
mkldnn::impl::parallel_nd(n, m, [&] (int j, int i) {
- double coffset = OCisR ? i2d(co[j]) : OCisC ? i2d(co[i]) : i2d(co[0]);
+ double coffset = OCisR ? i2d(oc[j]) : OCisC ? i2d(oc[i]) : i2d(oc[0]);
double val = ((beta == 0.0f) ? 0.0 : f2d(beta) * i2d(C[i + j * ldc]))
+ f2d(alpha) * dC[i + j * ldc] + coffset;
C[i + j * ldc] =
test_free((char *)dC);
}
-template <typename T>
-void compare(int M, int N, int ldc, T *C, T *C_ref, int K = 1) {
- mkldnn::impl::parallel_nd(N, ldc, [&](int i, int j) {
- T ref = C_ref[i*ldc + j];
- T got = C[i*ldc + j];
- T diff = got - ref;
- if (data_traits<T>::data_type == memory::data_type::f32) {
- T e = (std::abs(ref) > 1e-4) ? diff / ref : diff;
- EXPECT_NEAR(e, 0.0, 1e-4)
- << "Row: " << j << " Column: " << i;
+template <typename b_dt, typename c_dt>
+void compare(int m, int n, const c_dt *c, const c_dt *c_ref, int ldc,
+ float alpha = 1.0f, float beta = 0.0f, int k = 1) {
+ using data_type = memory::data_type;
+ mkldnn::impl::parallel_nd(n, ldc, [&](int i, int j) {
+ c_dt ref = c_ref[i*ldc + j];
+ c_dt got = c[i*ldc + j];
+ c_dt diff = got - ref;
+
+ if (data_traits<b_dt>::data_type == data_type::f32) {
+ c_dt e = (std::abs(ref) > 1e-4) ? diff / ref : diff;
+ EXPECT_NEAR(e, 0.0, 1e-4) << "Row: " << j << " Col: " << i;
} else {
- T eps = K / 1000 + 1;
- EXPECT_NEAR(diff, 0, eps)
- << "Row: " << j << " Column: " << i;
+ // igemm
+ if (alpha == 1.0f) {
+ EXPECT_NEAR(diff, 0, 1) << "Row: " << j << " Col: " << i;
+ } else {
+ if (data_traits<b_dt>::data_type == data_type::u8) {
+ c_dt eps = k / 1000 + 1;
+ EXPECT_NEAR(diff, 0, eps) << "Row: " << j << " Col: " << i;
+ } else if (data_traits<b_dt>::data_type == data_type::s8) {
+ c_dt eps = k / 500 + 1;
+ EXPECT_NEAR(diff, 0, eps) << "Row: " << j << " Col: " << i;
+ }
+ }
}
});
}
}
template <typename a_dt, typename b_dt, typename c_dt>
-inline void fill_matrix(size_t sizeA, size_t sizeB, size_t sizeC, size_t sizeco,
- a_dt *A, b_dt *B, c_dt *C, a_dt *ao, a_dt *bo, c_dt *co) {
+inline void fill_matrix(const test_params &p, size_t sizeA, size_t sizeB,
+ size_t sizeC, size_t sizeco, a_dt *A, b_dt *B, c_dt *C, a_dt *oa,
+ a_dt *ob, c_dt *oc) {
fill_data<a_dt>(sizeA, A);
fill_data<b_dt>(sizeB, B);
fill_data<c_dt>(sizeC, C);
- if (ao != nullptr && bo != nullptr && co != nullptr) {
- fill_data<a_dt>(1, ao);
- fill_data<a_dt>(1, bo);
- fill_data<c_dt>(sizeco, co);
+ if (oa != nullptr && ob != nullptr && oc != nullptr) {
+ if (p.igemm_params.zero_oa) (*oa) = 0;
+ else fill_data<a_dt>(1, oa);
+
+ if (p.igemm_params.zero_ob) (*ob) = 0;
+ else fill_data<a_dt>(1, ob);
+
+ if (p.igemm_params.zero_oc) {
+ for (size_t i = 0; i < sizeco; i++)
+ oc[i] = 0;
+ } else fill_data<c_dt>(sizeco, oc);
}
}
int32_t *C = get_matrix_buffer<int32_t>(sizeC);
int32_t *C_ref = get_matrix_buffer<int32_t>(sizeC);
- bool OCisR = (p.offsetc == 'R' || p.offsetc == 'r');
- bool OCisC = (p.offsetc == 'C' || p.offsetc == 'c');
+ bool OCisR = (p.igemm_params.offsetc == 'R' || p.igemm_params.offsetc == 'r');
+ bool OCisC = (p.igemm_params.offsetc == 'C' || p.igemm_params.offsetc == 'c');
size_t sizeco = OCisR ? p.N : OCisC ? p.M : 1;
- int8_t ao, bo;
- int32_t *co = get_matrix_buffer<int32_t>(sizeco);
+ int8_t oa, ob;
+ int32_t *oc = get_matrix_buffer<int32_t>(sizeco);
- fill_matrix<int8_t, uint8_t, int32_t>(sizeA, sizeB, sizeC, sizeco, A, B, C,
- &ao, &bo, co);
+ fill_matrix<int8_t, uint8_t, int32_t>(p, sizeA, sizeB, sizeC, sizeco,
+ A, B, C, &oa, &ob, oc);
mkldnn::impl::parallel_nd(p.ldc * p.N,
[&](int i) { C_ref[i] = static_cast<int32_t>(C[i]); });
- auto status = mkldnn_gemm_s8u8s32(&p.transA, &p.transB, &p.offsetc,
- &p.M, &p.N, &p.K, &p.alpha, A, &p.lda, &ao, B, &p.ldb, &bo,
- &p.beta, C, &p.ldc, co);
+ auto status = mkldnn_gemm_s8u8s32(&p.transA, &p.transB, &p.igemm_params.offsetc,
+ &p.M, &p.N, &p.K, &p.alpha, A, &p.lda, &oa, B, &p.ldb, &ob,
+ &p.beta, C, &p.ldc, oc);
if (status != mkldnn_success)
throw error(status, "mkldnn_gemm_s8u8s32 returned error");
- ref_gemm_s8x8s32<uint8_t>(&p.transA, &p.transB, &p.offsetc, p.M, p.N,
- p.K, p.alpha, A, p.lda, &ao, B, p.ldb, &bo, p.beta, C_ref,
- p.ldc, co);
+ ref_gemm_s8x8s32<uint8_t>(&p.transA, &p.transB, &p.igemm_params.offsetc, p.M, p.N,
+ p.K, p.alpha, A, p.lda, &oa, B, p.ldb, &ob, p.beta, C_ref,
+ p.ldc, oc);
- compare(p.M, p.N, p.ldc, C, C_ref, p.K);
+ compare<uint8_t, int32_t>(p.M, p.N, C, C_ref, p.ldc, p.alpha, p.beta, p.K);
test_free((char *)A);
test_free((char *)B);
test_free((char *)C);
test_free((char *)C_ref);
- test_free((char *)co);
+ test_free((char *)oc);
}
template <>
int32_t *C = get_matrix_buffer<int32_t>(sizeC);
int32_t *C_ref = get_matrix_buffer<int32_t>(sizeC);
- bool OCisR = (p.offsetc == 'R' || p.offsetc == 'r');
- bool OCisC = (p.offsetc == 'C' || p.offsetc == 'c');
+ bool OCisR = (p.igemm_params.offsetc == 'R' || p.igemm_params.offsetc == 'r');
+ bool OCisC = (p.igemm_params.offsetc == 'C' || p.igemm_params.offsetc == 'c');
size_t sizeco = OCisR ? p.N : OCisC ? p.M : 1;
- int8_t ao, bo;
- int32_t* co = get_matrix_buffer<int32_t>(sizeco);
+ int8_t oa, ob;
+ int32_t* oc = get_matrix_buffer<int32_t>(sizeco);
- fill_matrix<int8_t, int8_t, int32_t>(sizeA, sizeB, sizeC, sizeco, A, B, C,
- &ao, &bo, co);
+ fill_matrix<int8_t, int8_t, int32_t>(p, sizeA, sizeB, sizeC, sizeco, A, B, C,
+ &oa, &ob, oc);
mkldnn::impl::parallel_nd(p.ldc * p.N,
[&](int i) { C_ref[i] = static_cast<int32_t>(C[i]); });
- auto status = mkldnn_gemm_s8s8s32(&p.transA, &p.transB, &p.offsetc,
- &p.M, &p.N, &p.K, &p.alpha, A, &p.lda, &ao, B, &p.ldb, &bo,
- &p.beta, C, &p.ldc, co);
+ auto status = mkldnn_gemm_s8s8s32(&p.transA, &p.transB, &p.igemm_params.offsetc,
+ &p.M, &p.N, &p.K, &p.alpha, A, &p.lda, &oa, B, &p.ldb, &ob,
+ &p.beta, C, &p.ldc, oc);
if (status != mkldnn_success)
throw error(status, "mkldnn_gemm_s8s8s32 returned error");
- ref_gemm_s8x8s32<int8_t>(&p.transA, &p.transB, &p.offsetc, p.M, p.N,
- p.K, p.alpha, A, p.lda, &ao, B, p.ldb, &bo, p.beta, C_ref,
- p.ldc, co);
+ ref_gemm_s8x8s32<int8_t>(&p.transA, &p.transB, &p.igemm_params.offsetc, p.M, p.N,
+ p.K, p.alpha, A, p.lda, &oa, B, p.ldb, &ob, p.beta, C_ref,
+ p.ldc, oc);
- compare(p.M, p.N, p.ldc, C, C_ref, p.K);
+ compare<int8_t, int32_t>(p.M, p.N, C, C_ref, p.ldc, p.alpha, p.beta, p.K);
test_free((char *)A);
test_free((char *)B);
test_free((char *)C);
test_free((char *)C_ref);
- test_free((char *)co);
+ test_free((char *)oc);
}
template <>
float *C = get_matrix_buffer<float>(sizeC);
float *C_ref = get_matrix_buffer<float>(sizeC);
- fill_matrix<float, float, float>(sizeA, sizeB, sizeC, 0, A, B, C,
+ fill_matrix<float, float, float>(p, sizeA, sizeB, sizeC, 0, A, B, C,
nullptr, nullptr, nullptr);
mkldnn::impl::parallel_nd(p.N * p.ldc, [&](int i) { C_ref[i] = C[i]; });
if (status == mkldnn_success) {
ref_gemm(&p.transA, &p.transB, p.M, p.N, p.K, p.alpha, A, p.lda, B, p.ldb,
p.beta, C_ref, p.ldc);
- compare(p.M, p.N, p.ldc, C, C_ref);
+ compare<float, float>(p.M, p.N, C, C_ref, p.ldc);
}
test_free((char *)A);