--- /dev/null
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for XLA listdiff operator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class ListDiffTest(xla_test.XLATestCase):
+
+ def _testListDiff(self, x, y, out, idx):
+ for dtype in [dtypes.int32, dtypes.int64]:
+ for index_dtype in [dtypes.int32, dtypes.int64]:
+ with self.test_session() as sess:
+ x_tensor = ops.convert_to_tensor(x, dtype=dtype)
+ y_tensor = ops.convert_to_tensor(y, dtype=dtype)
+ with self.test_scope():
+ out_tensor, idx_tensor = array_ops.listdiff(
+ x_tensor, y_tensor, out_idx=index_dtype)
+ tf_out, tf_idx = sess.run([out_tensor, idx_tensor])
+ self.assertAllEqual(out, tf_out)
+ self.assertAllEqual(idx, tf_idx)
+ self.assertEqual(1, out_tensor.get_shape().ndims)
+ self.assertEqual(1, idx_tensor.get_shape().ndims)
+
+ def testBasic1(self):
+ self._testListDiff(x=[1, 2, 3, 4], y=[1, 2], out=[3, 4], idx=[2, 3])
+
+ def testBasic2(self):
+ self._testListDiff(x=[1, 2, 3, 4], y=[2], out=[1, 3, 4], idx=[0, 2, 3])
+
+ def testBasic3(self):
+ self._testListDiff(x=[1, 4, 3, 2], y=[4, 2], out=[1, 3], idx=[0, 2])
+
+ def testDuplicates(self):
+ self._testListDiff(x=[1, 2, 4, 3, 2, 3, 3, 1],
+ y=[4, 2],
+ out=[1, 3, 3, 3, 1],
+ idx=[0, 3, 5, 6, 7])
+
+ def testRandom(self):
+ num_random_tests = 10
+ int_low = -7
+ int_high = 8
+ max_size = 50
+ for _ in xrange(num_random_tests):
+ x_size = np.random.randint(max_size + 1)
+ x = np.random.randint(int_low, int_high, size=x_size)
+ y_size = np.random.randint(max_size + 1)
+ y = np.random.randint(int_low, int_high, size=y_size)
+ out_idx = [(entry, pos) for pos, entry in enumerate(x) if entry not in y]
+ if out_idx:
+ out, idx = map(list, zip(*out_idx))
+ else:
+ out = []
+ idx = []
+ self._testListDiff(list(x), list(y), out, idx)
+
+ def testFullyOverlapping(self):
+ self._testListDiff(x=[1, 2, 3, 4], y=[1, 2, 3, 4], out=[], idx=[])
+
+ def testNonOverlapping(self):
+ self._testListDiff(x=[1, 2, 3, 4],
+ y=[5, 6],
+ out=[1, 2, 3, 4],
+ idx=[0, 1, 2, 3])
+
+ def testEmptyX(self):
+ self._testListDiff(x=[], y=[1, 2], out=[], idx=[])
+
+ def testEmptyY(self):
+ self._testListDiff(x=[1, 2, 3, 4], y=[], out=[1, 2, 3, 4], idx=[0, 1, 2, 3])
+
+ def testEmptyXY(self):
+ self._testListDiff(x=[], y=[], out=[], idx=[])
+
+
+if __name__ == "__main__":
+ test.main()
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA-specific ListDiff Op. This only supports constant DT_INT32 and DT_INT64
+// input.
+
+#include <unordered_set>
+
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace {
+
+constexpr std::array<DataType, 2> kListDiffTypes = {DT_INT32, DT_INT64};
+
+// ListDiffOp is an XLA kernel that supports constant-only x and y input.
+class ListDiffOp : public XlaOpKernel {
+ public:
+ explicit ListDiffOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
+
+ void Compile(XlaOpKernelContext* context) override {
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(context->InputShape(0)),
+ errors::InvalidArgument("ListDiff expects x as a vector, not ",
+ context->InputShape(0).DebugString()));
+
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(context->InputShape(1)),
+ errors::InvalidArgument("ListDiff expects y as a vector, not ",
+ context->InputShape(1).DebugString()));
+
+ DataType val_type = context->expected_output_dtype(0);
+ DataType idx_type = context->expected_output_dtype(1);
+
+ Status status;
+ switch (val_type) {
+ case DT_INT32:
+ status = ListDiffWithIndexType<int32>(context, idx_type);
+ break;
+ case DT_INT64:
+ status = ListDiffWithIndexType<int64>(context, idx_type);
+ break;
+ default:
+ // This should never happen since we restrict this kernel to only match
+ // inputs with supported Tensor datatype.
+ status = errors::InvalidArgument("ListDiff expects x and y as either ",
+ "int32 or int64, not ",
+ DataTypeString(val_type));
+ }
+ OP_REQUIRES_OK(context, status);
+ }
+
+ private:
+ template <typename Tval, typename Tidx>
+ Status ListDiff(XlaOpKernelContext* context) {
+ std::vector<int64> x_input, y_input;
+ TF_RETURN_IF_ERROR(context->ConstantInputAsIntVector(0, &x_input));
+ TF_RETURN_IF_ERROR(context->ConstantInputAsIntVector(1, &y_input));
+
+ std::unordered_set<Tval> y_input_set;
+ y_input_set.reserve(y_input.size());
+ for (auto y : y_input) {
+ y_input_set.insert(y);
+ }
+
+ std::vector<Tval> val_output;
+ std::vector<Tidx> idx_output;
+ auto x_size = x_input.size();
+ for (Tidx i = 0; i < x_size; ++i) {
+ if (y_input_set.count(x_input[i]) > 0) {
+ continue;
+ }
+ val_output.push_back(x_input[i]);
+ idx_output.push_back(i);
+ }
+
+ context->SetOutput(0, context->builder()->ConstantR1<Tval>(val_output));
+ context->SetOutput(1, context->builder()->ConstantR1<Tidx>(idx_output));
+ return Status::OK();
+ }
+
+ template <typename Tval>
+ Status ListDiffWithIndexType(XlaOpKernelContext* context, DataType idx_type) {
+ switch (idx_type) {
+ case DT_INT32:
+ return ListDiff<Tval, int32>(context);
+ case DT_INT64:
+ return ListDiff<Tval, int64>(context);
+ default:
+ return errors::InvalidArgument(
+ "ListDiff expects idx_out as either int32 or int64, not ",
+ DataTypeString(idx_type));
+ }
+ }
+};
+
+REGISTER_XLA_OP(Name("ListDiff")
+ .TypeConstraint("T", kListDiffTypes)
+ .CompileTimeConstInput("x")
+ .CompileTimeConstInput("y"),
+ ListDiffOp);
+
+} // namespace
+} // namespace tensorflow