Do not use SparseMatmul for bfloat16, as Matmul already supports it.
author: A. Unique TensorFlower <gardener@tensorflow.org>
Mon, 19 Mar 2018 18:29:44 +0000 (11:29 -0700)
committer: TensorFlower Gardener <gardener@tensorflow.org>
Mon, 19 Mar 2018 18:36:11 +0000 (11:36 -0700)
PiperOrigin-RevId: 189614197

tensorflow/python/ops/math_ops.py

index e18d0e9..c893bf9 100644 (file)
@@ -2093,8 +2093,9 @@ def matmul(a,
       sparse_matmul_types = [dtypes.bfloat16, dtypes.float32]
       use_sparse_matmul = (
           a.dtype in sparse_matmul_types and b.dtype in sparse_matmul_types)
-    if a.dtype == dtypes.bfloat16 or b.dtype == dtypes.bfloat16:
-      # matmul currently doesn't handle bfloat16 inputs.
+    if (a.dtype == dtypes.bfloat16 or b.dtype == dtypes.bfloat16 and
+        a.dtype != b.dtype):
+      # matmul currently doesn't handle mixed-precision inputs.
       use_sparse_matmul = True
     if use_sparse_matmul:
       ret = sparse_matmul(