[CPU][BF16] bf16 for Gemm or MatMul was enabled (#1920)
authorAlexey Varyzgin <alexey.varyzgin@intel.com>
Fri, 4 Sep 2020 07:04:02 +0000 (10:04 +0300)
committerGitHub <noreply@github.com>
Fri, 4 Sep 2020 07:04:02 +0000 (10:04 +0300)
inference-engine/include/ie_precision.hpp
inference-engine/src/mkldnn_plugin/bf16transformer.h
inference-engine/src/mkldnn_plugin/nodes/mkldnn_gemm_node.cpp
inference-engine/thirdparty/mkl-dnn

index 2c8f3fe..8d13a4b 100644 (file)
@@ -224,7 +224,7 @@ public:
                (precisionInfo.value == Precision::Q78) || (precisionInfo.value == Precision::I16) ||
                (precisionInfo.value == Precision::I8) || (precisionInfo.value == Precision::I32) ||
                (precisionInfo.value == Precision::I64) || (precisionInfo.value == Precision::BIN) ||
-               (precisionInfo.value == Precision::CUSTOM);
+               (precisionInfo.value == Precision::BF16) || (precisionInfo.value == Precision::CUSTOM);
     }
 
 protected:
index 370656e..6ff30cd 100644 (file)
@@ -13,7 +13,7 @@ namespace MKLDNNPlugin {
 
 class BF16Transformer {
     const InferenceEngine::details::caseless_set<std::string> _initbf16 =
-        { "convolution", "fullyconnected", "innerproduct" };
+        { "convolution", "fullyconnected", "innerproduct", "gemm" };
     const InferenceEngine::details::caseless_set<std::string> _complementbf16 =
         { "relu", "tanh", "elu", "square", "abs", "sqrt", "linear", "bounded_relu", "soft_relu", "logistic",
           "exp", "gelu", "clamp", "swish", "prelu", "pooling", "norm", "gather", "memory" };
index 1b003fc..a1c42af 100644 (file)
@@ -123,8 +123,13 @@ void MKLDNNGemmNode::initSupportedPrimitiveDescriptors() {
     auto inPrec0 = getCnnLayer()->insData[0].lock()->getPrecision();
     auto inPrec1 = getCnnLayer()->insData[1].lock()->getPrecision();
     if ((inPrec0 != Precision::U8 && inPrec0 != Precision::I8) || inPrec1 != Precision::I8 || isThreeInputs) {
-        inPrec0 = Precision::FP32;
-        inPrec1 = Precision::FP32;
+        if (inPrec0 == Precision::BF16 || inPrec1 == Precision::BF16) {
+            inPrec0 = Precision::BF16;
+            inPrec1 = Precision::BF16;
+        } else {
+            inPrec0 = Precision::FP32;
+            inPrec1 = Precision::FP32;
+        }
     }
 
     auto inputDataType0 = MKLDNNExtensionUtils::IEPrecisionToDataType(inPrec0);
@@ -193,6 +198,11 @@ inline void process_gemm(char transa, char transb, int M, int N, int K, float al
     mkldnn_sgemm(transa, transb, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
 }
 
+inline void process_gemm(char transa, char transb, int M, int N, int K, float alpha, const uint16_t *A, int lda,
+                         const uint16_t *B, int ldb, float beta, float *C, int ldc) {
+    mkldnn_gemm_bf16bf16f32(transa, transb, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
+}
+
 inline void process_gemm(char transa, char transb, int M, int N, int K, float alpha, const uint8_t *A, int lda,
                          const int8_t *B, int ldb, float beta, float *C, int ldc) {
     const int32_t co = 0;
@@ -289,6 +299,9 @@ void MKLDNNGemmNode::execute(mkldnn::stream strm) {
         case Precision::FP32:
             process_data<float, float>();
             break;
+        case Precision::BF16:
+            process_data<uint16_t, uint16_t>();
+            break;
         case Precision::I8:
             process_data<int8_t, int8_t>();
             break;
index b73474c..b96a547 160000 (submodule)
@@ -1 +1 @@
-Subproject commit b73474c80c21ae170b112803a1fc315e1549bdab
+Subproject commit b96a54762aedf27711f8e9144d05b2697b03cc40