[Contrib] Add MKL DNN option (#4323)
authorHaichen Shen <shenhaichen@gmail.com>
Fri, 15 Nov 2019 03:45:57 +0000 (19:45 -0800)
committerTianqi Chen <tqchen@users.noreply.github.com>
Fri, 15 Nov 2019 03:45:57 +0000 (19:45 -0800)
* [Contrib] Add MKL DNN

* update

* update

CMakeLists.txt
cmake/config.cmake
cmake/modules/contrib/BLAS.cmake
src/runtime/contrib/cblas/cblas.cc
topi/python/topi/x86/dense.py

index 2bea818..c99fe0d 100644 (file)
@@ -53,6 +53,7 @@ tvm_option(PICOJSON_PATH "Path to PicoJSON" "3rdparty/picojson")
 # Contrib library options
 tvm_option(USE_BLAS "The blas library to be linked" none)
 tvm_option(USE_MKL_PATH "MKL root path when use MKL blas" none)
+tvm_option(USE_MKLDNN "Build with MKLDNN" OFF)
 tvm_option(USE_CUDNN "Build with cuDNN" OFF)
 tvm_option(USE_CUBLAS "Build with cuBLAS" OFF)
 tvm_option(USE_MIOPEN "Build with ROCM:MIOpen" OFF)
index 51c9292..119754a 100644 (file)
@@ -115,6 +115,9 @@ set(USE_BLAS none)
 # set(USE_MKL_PATH <path to venv or site-packages directory>) if using `pip install mkl`
 set(USE_MKL_PATH none)
 
+# Whether use MKLDNN library
+set(USE_MKLDNN OFF)
+
 # Whether use OpenMP thread pool, choices: gnu, intel
 # Note: "gnu" uses gomp library, "intel" uses iomp5 library
 set(USE_OPENMP none)
index 6a58287..bd8c0d0 100644 (file)
@@ -55,3 +55,10 @@ elseif(USE_BLAS STREQUAL "none")
 else()
   message(FATAL_ERROR "Invalid option: USE_BLAS=" ${USE_BLAS})
 endif()
+
+if(USE_MKLDNN STREQUAL "ON")
+  find_library(BLAS_LIBRARY_MKLDNN dnnl)
+  list(APPEND TVM_RUNTIME_LINKER_LIBS ${BLAS_LIBRARY_MKLDNN})
+  add_definitions(-DUSE_DNNL=1)
+  message(STATUS "Use MKLDNN library " ${BLAS_LIBRARY_MKLDNN})
+endif()
index c655867..ef9f5d6 100644 (file)
@@ -31,6 +31,9 @@ extern "C" {
 #else
 #include <cblas.h>
 #endif
+#if USE_DNNL == 1
+#include <dnnl.h>
+#endif
 }
 
 namespace tvm {
@@ -40,12 +43,19 @@ using namespace runtime;
 
 inline CBLAS_TRANSPOSE BooleanToTranspose(bool trans) { return trans ? CblasTrans : CblasNoTrans; }
 
+inline char BooleanToTransposeChar(bool trans) { return trans ? 'T' : 'N'; }
+
 struct CblasSgemmOp {
   typedef float TDatatype;
   void operator()(bool ta, bool tb, int M, int N, int K, float alpha, float* A, int lda, float* B,
                   int ldb, float beta, float* C, int ldc) {
+#if USE_DNNL == 1
+    dnnl_sgemm(BooleanToTransposeChar(tb), BooleanToTransposeChar(ta), N, M, K, alpha, B,
+               ldb, A, lda, beta, C, ldc);
+#else
     cblas_sgemm(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb), M, N, K, alpha, A,
                 lda, B, ldb, beta, C, ldc);
+#endif
   }
 };
 
index 2a739d5..605a175 100644 (file)
@@ -32,7 +32,7 @@ def _declaration_dense(cfg, data, weight, bias=None, out_dtype=None):
     if "cblas" in target.libs:
         C = cblas.matmul(data, weight, False, True)
         if bias is not None:
-            C = tvm.compute(C.shape, lambda i, j: C[i, j] + bias[j].astype(out_dtype),
+            C = tvm.compute(C.shape, lambda i, j: C[i, j] + bias[j],
                             tag=tag.BROADCAST)
         return C