Adds test_util.IsMklEnabled() that returns true if TensorFlow has been built with MKL support
author    Tatiana Shpeisman <shpeisman@google.com>
          Mon, 5 Mar 2018 22:40:46 +0000 (14:40 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
          Mon, 5 Mar 2018 22:44:58 +0000 (14:44 -0800)
PiperOrigin-RevId: 187925038

tensorflow/core/util/port.cc
tensorflow/core/util/port.h
tensorflow/python/framework/test_util.py
tensorflow/python/framework/test_util_test.py
tensorflow/python/tools/print_selective_registration_header_test.py
tensorflow/python/util/port.i
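
For context, a minimal usage sketch of the new API. Only test_util.IsMklEnabled() itself is added by this commit; the skip-test pattern and test names below are illustrative, not part of the change.

    # Hypothetical example: gating a test on the MKL build flag.
    from tensorflow.python.framework import test_util
    from tensorflow.python.platform import test


    class MklGatedTest(test_util.TensorFlowTestCase):

      def testMklSpecificBehavior(self):
        if not test_util.IsMklEnabled():
          self.skipTest("TensorFlow was not built with Intel MKL")
        # MKL-specific assertions would go here.


    if __name__ == "__main__":
      test.main()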

diff --git a/tensorflow/core/util/port.cc b/tensorflow/core/util/port.cc
index d93b971..490c584 100644
--- a/tensorflow/core/util/port.cc
+++ b/tensorflow/core/util/port.cc
@@ -39,4 +39,11 @@ bool CudaSupportsHalfMatMulAndConv() {
 #endif
 }
 
+bool IsMklEnabled() {
+#ifdef INTEL_MKL
+  return true;
+#else
+  return false;
+#endif
+}
 }  // end namespace tensorflow
diff --git a/tensorflow/core/util/port.h b/tensorflow/core/util/port.h
index ed65341..981def9 100644
--- a/tensorflow/core/util/port.h
+++ b/tensorflow/core/util/port.h
@@ -25,6 +25,9 @@ bool IsGoogleCudaEnabled();
 // half-precision matrix multiplications and convolution operations.
 bool CudaSupportsHalfMatMulAndConv();
 
+// Returns true if INTEL_MKL is defined.
+bool IsMklEnabled();
+
 }  // end namespace tensorflow
 
 #endif  // TENSORFLOW_UTIL_PORT_H_
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index aabf89a..78252e4 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -200,11 +200,14 @@ def _strip_checkpoint_v2_randomized(graph_def):
 def IsGoogleCudaEnabled():
   return pywrap_tensorflow.IsGoogleCudaEnabled()
 
-
 def CudaSupportsHalfMatMulAndConv():
   return pywrap_tensorflow.CudaSupportsHalfMatMulAndConv()
 
 
+def IsMklEnabled():
+  return pywrap_tensorflow.IsMklEnabled()
+
+
 def InstallStackTraceHandler():
   pywrap_tensorflow.InstallStacktraceHandler()
 
diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py
index a717eb3..20d8160 100644
--- a/tensorflow/python/framework/test_util_test.py
+++ b/tensorflow/python/framework/test_util_test.py
@@ -82,6 +82,14 @@ class TestUtilTest(test_util.TensorFlowTestCase):
     else:
       print("GoogleCuda is disabled")
 
+  def testIsMklEnabled(self):
+    # This test doesn't assert anything.
+    # It ensures the py wrapper function is generated correctly.
+    if test_util.IsMklEnabled():
+      print("MKL is enabled")
+    else:
+      print("MKL is disabled")
+
   def testAssertProtoEqualsStr(self):
 
     graph_str = "node { name: 'w1' op: 'params' }"
diff --git a/tensorflow/python/tools/print_selective_registration_header_test.py b/tensorflow/python/tools/print_selective_registration_header_test.py
index 36978b0..4b3d982 100644
--- a/tensorflow/python/tools/print_selective_registration_header_test.py
+++ b/tensorflow/python/tools/print_selective_registration_header_test.py
@@ -24,6 +24,7 @@ import sys
 from google.protobuf import text_format
 
 from tensorflow.core.framework import graph_pb2
+from tensorflow.python.framework import test_util
 from tensorflow.python.platform import gfile
 from tensorflow.python.platform import test
 from tensorflow.python.tools import selective_registration_header_lib
@@ -93,11 +94,16 @@ class PrintOpFilegroupTest(test.TestCase):
 
     ops_and_kernels = selective_registration_header_lib.get_ops_and_kernels(
         'rawproto', self.WriteGraphFiles(graphs), default_ops)
+    matmul_prefix = ''
+    if test_util.IsMklEnabled():
+      matmul_prefix = 'Mkl'
+
     self.assertListEqual(
         [
             ('BiasAdd', 'BiasOp<CPUDevice, float>'),  #
-            ('MatMul', 'MatMulOp<CPUDevice, double, false >'),  #
-            ('MatMul', 'MatMulOp<CPUDevice, float, false >'),  #
+            ('MatMul',
+             matmul_prefix + 'MatMulOp<CPUDevice, double, false >'),  #
+            ('MatMul', matmul_prefix + 'MatMulOp<CPUDevice, float, false >'),  #
             ('NoOp', 'NoOp'),  #
             ('Reshape', 'ReshapeOp'),  #
             ('_Recv', 'RecvOp'),  #
@@ -112,8 +118,9 @@ class PrintOpFilegroupTest(test.TestCase):
     self.assertListEqual(
         [
             ('BiasAdd', 'BiasOp<CPUDevice, float>'),  #
-            ('MatMul', 'MatMulOp<CPUDevice, double, false >'),  #
-            ('MatMul', 'MatMulOp<CPUDevice, float, false >'),  #
+            ('MatMul',
+             matmul_prefix + 'MatMulOp<CPUDevice, double, false >'),  #
+            ('MatMul', matmul_prefix + 'MatMulOp<CPUDevice, float, false >'),  #
             ('NoOp', 'NoOp'),  #
             ('Reshape', 'ReshapeOp'),  #
             ('_Recv', 'RecvOp'),  #
diff --git a/tensorflow/python/util/port.i b/tensorflow/python/util/port.i
index cea4d84..2f73073 100644
--- a/tensorflow/python/util/port.i
+++ b/tensorflow/python/util/port.i
@@ -23,5 +23,6 @@ limitations under the License.
 %unignore tensorflow;
 %unignore tensorflow::IsGoogleCudaEnabled;
 %unignore tensorflow::CudaSupportsHalfMatMulAndConv;
+%unignore tensorflow::IsMklEnabled;
 %include "tensorflow/core/util/port.h"
 %unignoreall
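
For reference, the full exposure chain after this change: the %unignore directive above lets SWIG generate a Python binding for the C++ symbol, and test_util.IsMklEnabled() forwards to it. A minimal sketch, assuming the module layout shown in the diffs above:

    from tensorflow.python import pywrap_tensorflow
    from tensorflow.python.framework import test_util

    # Both calls reach the same C++ function, tensorflow::IsMklEnabled(),
    # whose return value is fixed at build time by the INTEL_MKL define.
    print(pywrap_tensorflow.IsMklEnabled())
    print(test_util.IsMklEnabled())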