Implement constant-only ListDiff Op in XLA to support dense layer.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 11 May 2018 22:07:24 +0000 (15:07 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 11 May 2018 22:15:50 +0000 (15:15 -0700)
PiperOrigin-RevId: 196315170

tensorflow/compiler/tests/BUILD
tensorflow/compiler/tests/listdiff_op_test.py [new file with mode: 0644]
tensorflow/compiler/tf2xla/kernels/BUILD
tensorflow/compiler/tf2xla/kernels/listdiff_op.cc [new file with mode: 0644]

index 9791792..96dfc8d 100644 (file)
@@ -410,6 +410,21 @@ tf_xla_py_test(
 )
 
 tf_xla_py_test(
+    name = "listdiff_op_test",
+    size = "small",
+    srcs = ["listdiff_op_test.py"],
+    deps = [
+        ":xla_test",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:data_flow_ops",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:platform_test",
+        "@six_archive//:six",
+    ],
+)
+
+tf_xla_py_test(
     name = "lrn_ops_test",
     size = "medium",
     srcs = ["lrn_ops_test.py"],
diff --git a/tensorflow/compiler/tests/listdiff_op_test.py b/tensorflow/compiler/tests/listdiff_op_test.py
new file mode 100644 (file)
index 0000000..45a04f0
--- /dev/null
@@ -0,0 +1,101 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for XLA listdiff operator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from six.moves import xrange  # pylint: disable=redefined-builtin
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class ListDiffTest(xla_test.XLATestCase):
+
+  def _testListDiff(self, x, y, out, idx):
+    for dtype in [dtypes.int32, dtypes.int64]:
+      for index_dtype in [dtypes.int32, dtypes.int64]:
+        with self.test_session() as sess:
+          x_tensor = ops.convert_to_tensor(x, dtype=dtype)
+          y_tensor = ops.convert_to_tensor(y, dtype=dtype)
+          with self.test_scope():
+            out_tensor, idx_tensor = array_ops.listdiff(
+                x_tensor, y_tensor, out_idx=index_dtype)
+            tf_out, tf_idx = sess.run([out_tensor, idx_tensor])
+        self.assertAllEqual(out, tf_out)
+        self.assertAllEqual(idx, tf_idx)
+        self.assertEqual(1, out_tensor.get_shape().ndims)
+        self.assertEqual(1, idx_tensor.get_shape().ndims)
+
+  def testBasic1(self):
+    self._testListDiff(x=[1, 2, 3, 4], y=[1, 2], out=[3, 4], idx=[2, 3])
+
+  def testBasic2(self):
+    self._testListDiff(x=[1, 2, 3, 4], y=[2], out=[1, 3, 4], idx=[0, 2, 3])
+
+  def testBasic3(self):
+    self._testListDiff(x=[1, 4, 3, 2], y=[4, 2], out=[1, 3], idx=[0, 2])
+
+  def testDuplicates(self):
+    self._testListDiff(x=[1, 2, 4, 3, 2, 3, 3, 1],
+                       y=[4, 2],
+                       out=[1, 3, 3, 3, 1],
+                       idx=[0, 3, 5, 6, 7])
+
+  def testRandom(self):
+    num_random_tests = 10
+    int_low = -7
+    int_high = 8
+    max_size = 50
+    for _ in xrange(num_random_tests):
+      x_size = np.random.randint(max_size + 1)
+      x = np.random.randint(int_low, int_high, size=x_size)
+      y_size = np.random.randint(max_size + 1)
+      y = np.random.randint(int_low, int_high, size=y_size)
+      out_idx = [(entry, pos) for pos, entry in enumerate(x) if entry not in y]
+      if out_idx:
+        out, idx = map(list, zip(*out_idx))
+      else:
+        out = []
+        idx = []
+      self._testListDiff(list(x), list(y), out, idx)
+
+  def testFullyOverlapping(self):
+    self._testListDiff(x=[1, 2, 3, 4], y=[1, 2, 3, 4], out=[], idx=[])
+
+  def testNonOverlapping(self):
+    self._testListDiff(x=[1, 2, 3, 4],
+                       y=[5, 6],
+                       out=[1, 2, 3, 4],
+                       idx=[0, 1, 2, 3])
+
+  def testEmptyX(self):
+    self._testListDiff(x=[], y=[1, 2], out=[], idx=[])
+
+  def testEmptyY(self):
+    self._testListDiff(x=[1, 2, 3, 4], y=[], out=[1, 2, 3, 4], idx=[0, 1, 2, 3])
+
+  def testEmptyXY(self):
+    self._testListDiff(x=[], y=[], out=[], idx=[])
+
+
+if __name__ == "__main__":
+  test.main()
index 85ab4c4..e6da157 100644 (file)
@@ -45,6 +45,7 @@ tf_kernel_library(
         "image_resize_ops.cc",
         "index_ops.cc",
         "l2loss_op.cc",
+        "listdiff_op.cc",
         "lrn_ops.cc",
         "matmul_op.cc",
         "matrix_band_part_op.cc",
diff --git a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc
new file mode 100644 (file)
index 0000000..0388b4c
--- /dev/null
@@ -0,0 +1,120 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA-specific ListDiff Op. This only supports constant DT_INT32 and DT_INT64
+// input.
+
+#include <unordered_set>
+
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace {
+
+constexpr std::array<DataType, 2> kListDiffTypes = {DT_INT32, DT_INT64};
+
+// ListDiffOp is an XLA kernel that supports constant-only x and y input.
+class ListDiffOp : public XlaOpKernel {
+ public:
+  explicit ListDiffOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
+
+  void Compile(XlaOpKernelContext* context) override {
+    OP_REQUIRES(context, TensorShapeUtils::IsVector(context->InputShape(0)),
+                errors::InvalidArgument("ListDiff expects x as a vector, not ",
+                                        context->InputShape(0).DebugString()));
+
+    OP_REQUIRES(context, TensorShapeUtils::IsVector(context->InputShape(1)),
+                errors::InvalidArgument("ListDiff expects y as a vector, not ",
+                                        context->InputShape(1).DebugString()));
+
+    DataType val_type = context->expected_output_dtype(0);
+    DataType idx_type = context->expected_output_dtype(1);
+
+    Status status;
+    switch (val_type) {
+      case DT_INT32:
+        status = ListDiffWithIndexType<int32>(context, idx_type);
+        break;
+      case DT_INT64:
+        status = ListDiffWithIndexType<int64>(context, idx_type);
+        break;
+      default:
+        // This should never happen since we restrict this kernel to only match
+        // inputs with supported Tensor datatype.
+        status = errors::InvalidArgument("ListDiff expects x and y as either ",
+                                         "int32 or int64, not ",
+                                         DataTypeString(val_type));
+    }
+    OP_REQUIRES_OK(context, status);
+  }
+
+ private:
+  template <typename Tval, typename Tidx>
+  Status ListDiff(XlaOpKernelContext* context) {
+    std::vector<int64> x_input, y_input;
+    TF_RETURN_IF_ERROR(context->ConstantInputAsIntVector(0, &x_input));
+    TF_RETURN_IF_ERROR(context->ConstantInputAsIntVector(1, &y_input));
+
+    std::unordered_set<Tval> y_input_set;
+    y_input_set.reserve(y_input.size());
+    for (auto y : y_input) {
+      y_input_set.insert(y);
+    }
+
+    std::vector<Tval> val_output;
+    std::vector<Tidx> idx_output;
+    auto x_size = x_input.size();
+    for (Tidx i = 0; i < x_size; ++i) {
+      if (y_input_set.count(x_input[i]) > 0) {
+        continue;
+      }
+      val_output.push_back(x_input[i]);
+      idx_output.push_back(i);
+    }
+
+    context->SetOutput(0, context->builder()->ConstantR1<Tval>(val_output));
+    context->SetOutput(1, context->builder()->ConstantR1<Tidx>(idx_output));
+    return Status::OK();
+  }
+
+  template <typename Tval>
+  Status ListDiffWithIndexType(XlaOpKernelContext* context, DataType idx_type) {
+    switch (idx_type) {
+      case DT_INT32:
+        return ListDiff<Tval, int32>(context);
+      case DT_INT64:
+        return ListDiff<Tval, int64>(context);
+      default:
+        return errors::InvalidArgument(
+            "ListDiff expects idx_out as either int32 or int64, not ",
+            DataTypeString(idx_type));
+    }
+  }
+};
+
+REGISTER_XLA_OP(Name("ListDiff")
+                    .TypeConstraint("T", kListDiffTypes)
+                    .CompileTimeConstInput("x")
+                    .CompileTimeConstInput("y"),
+                ListDiffOp);
+
+}  // namespace
+}  // namespace tensorflow