[XLA] Pattern matcher for HLO, Shapes, Layouts
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 13 Apr 2018 22:32:11 +0000 (15:32 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 13 Apr 2018 22:35:03 +0000 (15:35 -0700)
PiperOrigin-RevId: 192834129

tensorflow/compiler/xla/service/BUILD
tensorflow/compiler/xla/service/pattern_matcher.h [new file with mode: 0644]
tensorflow/compiler/xla/service/pattern_matcher_test.cc [new file with mode: 0644]
tensorflow/compiler/xla/shape_util.cc
tensorflow/compiler/xla/shape_util.h

index 65203fa..ddc0998 100644 (file)
@@ -303,6 +303,29 @@ tf_cc_test(
 )
 
 cc_library(
+    name = "pattern_matcher",
+    hdrs = ["pattern_matcher.h"],
+    deps = [
+        ":hlo",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/core:lib",
+    ],
+)
+
+tf_cc_test(
+    name = "pattern_matcher_test",
+    srcs = ["pattern_matcher_test.cc"],
+    deps = [
+        ":hlo",
+        ":pattern_matcher",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/compiler/xla/tools/parser:hlo_parser",
+        "//tensorflow/core:test",
+    ],
+)
+
+cc_library(
     name = "hlo_reachability",
     srcs = ["hlo_reachability.cc"],
     hdrs = ["hlo_reachability.h"],
diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h
new file mode 100644 (file)
index 0000000..5d49638
--- /dev/null
@@ -0,0 +1,1014 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
+
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace xla {
+
+// A pattern matcher for HloInstructions, Shapes, and Layouts.
+//
+// The Match function's first argument must be HloInstruction*, Shape*, or
+// Layout*. The second argument is a pattern that will be matched against the
+// first argument, as described below.
+//
+// Patterns are constructed using the match::Op, match::Shape, or match::Layout
+// functions. By default, the returned patterns will match any HloInstruction,
+// Shape, or Layout, respectively. However the match can be made more specific
+// by using the pattern's modifier methods, for example:
+//
+//   match::Op().WithOpcode(HloOpcode::kAdd).WithOperand(
+//     0, match::Op().WithOpcode(HloOpcode::kConstant))
+//
+// This pattern will match Add instructions whose first operand is a constant.
+//
+// Each pattern type has the following modifiers:
+//
+//   Op():
+//     - WithName: match operations with the given name
+//     - WithOpcode: match operations with the given opcode
+//     - WithShape: match operations whose shape matches the given pattern
+//     - WithOperand: match operations whose operand matches the given pattern
+//
+//   Shape():
+//     - EqualTo: matches shapes that are equal to the argument
+//     - CompatibleTo: matches shapes that are compatible to the argument
+//     - IsScalar/IsArray/IsTuple: matches scalar/array/tuple shapes
+//     - IsDenseArray/IsSparseArray: matches arrays with dense/sparse format
+//     - WithLayout: match shapes whose layout matches the given pattern
+//     - WithLayoutEqualTo: matches shapes whose layouts equal the argument
+//     - WithSubshape: matches tuple shapes whose subshape matches the given
+//       pattern
+//     - WithSubshapeEqualTo: matches shapes with a subshape equal the argument
+//     - WithElementType: matches array/scalar shapes with the given element
+//       type
+//     - WithRank: matches array/scalar types with the given rank
+//
+//  Layout():
+//     - EqualTo: matches layouts that are equal to the argument
+//     - WithDenseFormat/WithSparseFormat: matches layouts with dense/sparse
+//       format
+//
+// Op(), Shape(), and Layout() may be passed an argument of type
+// HloInstruction**, Shape**, or Layout**, respectively, or const versions of
+// these pointers. If the pattern is matched, the address of the matched value
+// will be "captured" and stored at this location.
+//
+// For example:
+//   HloInstruction* foo = ...;
+//   HloInstruction* matched_operand;
+//   CHECK(Match(foo,
+//               match::Op().WithOperand(0, match::Op(&matched_operand))));
+//
+// Helpers are provided for common nullary, unary, binary, and ternary
+// instructions. These helpers can be called with no arguments, in which case
+// they will match any instruction matching the opcode. They may also be called
+// with matches for the operands and with an optional capture. (The capture must
+// be the first argument.) Some examples of these helpers and their equivalents
+// are provided below.
+//
+// Example nullary instruction:
+//   Recv()                            == Op().WithOpcode(HloOpcode::kRecv)
+//   Recv(&a)                          == Op(&a).WithOpcode(HloOpcode::kRecv)
+//
+// Example unary instruction:
+//   Abs()                             == Op().WithOpcode(HloOpcode::kAbs)
+//   Abs(Op(&a))                       == Op().WithOpcode(HloOpcode::kAbs)
+//                                            .WithOperand(0, Op(&a)))
+//   Abs(&a, Op(&b))                   == Op(&a).WithOpcode(HloOpcode::kAbs)
+//                                              .WithOperand(0, Op(&b))
+//
+// Example binary instruction:
+//   Add()                             == Op().WithOpcode(HloOpcode::kAdd)
+//   Add(Op(&a), Op(&b))               == Op().WithOpcode(HloOpcode::kAdd)
+//                                            .WithOperand(0, Op(&a))
+//                                            .WithOperand(1, Op(&b))
+//   Add(&a, Op(&b), Op(&c))           == Op(&a).WithOpcode(HloOpcode::kAdd)
+//                                              .WithOperand(0, Op(&b))
+//                                              .WithOperand(1, Op(&c))
+//
+// Example ternary instruction:
+//   Clamp()                           == Op().WithOpcode(HloOpcode::kClamp)
+//   Clamp(Op(&a), Op(&b), Op(&c))     == Op().WithOpcode(HloOpcode::kClamp)
+//                                            .WithOperand(0, Op(&a))
+//                                            .WithOperand(1, Op(&b))
+//                                            .WithOperand(2, Op(&c))
+//   Clamp(&a, Op(&b), Op(&c), Op(&d)) == Op(&a).WithOpcode(HloOpcode::kClamp)
+//                                              .WithOperand(0, Op(&b))
+//                                              .WithOperand(1, Op(&c))
+//                                              .WithOperand(2, Op(&d))
+//
+template <typename Value, typename Pattern>
+bool Match(Value* value, const Pattern& pattern) {
+  return pattern.Match(value);
+}
+
+namespace match {
+
+namespace detail {
+
+template <typename LayoutType, typename Impl>
+class LayoutPattern;
+
+// The base LayoutPattern implementation. Matches only if the layout is not
+// nullptr.
+class LayoutPatternBaseImpl {
+ public:
+  bool Match(const ::xla::Layout* layout) const { return layout != nullptr; }
+};
+
+// A LayoutPattern implementation that matches only if the layout equals a
+// Layout proto.
+template <typename Previous>
+class LayoutPatternEqualImpl {
+ public:
+  explicit constexpr LayoutPatternEqualImpl(const Previous& previous,
+                                            const ::xla::Layout* layout)
+      : previous_(previous), layout_(layout) {}
+
+  bool Match(const ::xla::Layout* layout) const {
+    return previous_.Match(layout) && LayoutUtil::Equal(*layout_, *layout);
+  }
+
+ private:
+  Previous previous_;
+  const ::xla::Layout* layout_;
+};
+
+// A LayoutPattern implementation that matches only if the layout has a given
+// format.
+template <typename Previous>
+class LayoutPatternFormatImpl {
+ public:
+  explicit constexpr LayoutPatternFormatImpl(const Previous& previous,
+                                             Format format)
+      : previous_(previous), format_(format) {}
+
+  bool Match(const ::xla::Layout* layout) const {
+    return previous_.Match(layout) && layout->format() == format_;
+  }
+
+ private:
+  Previous previous_;
+  Format format_;
+};
+
+// A pattern that matches Layouts.
+template <typename LayoutType, typename Impl>
+class LayoutPattern {
+ public:
+  explicit constexpr LayoutPattern(const Impl& impl,
+                                   LayoutType** matched_layout)
+      : impl_(impl), matched_layout_(matched_layout) {}
+
+  // Returns true and captures the layout iff it matches the pattern.
+  bool Match(const ::xla::Layout* layout) const {
+    if (impl_.Match(layout)) {
+      if (matched_layout_) {
+        *matched_layout_ = layout;
+      }
+      return true;
+    }
+    return false;
+  }
+
+  // Returns true and captures the layout iff it matches the pattern.
+  bool Match(::xla::Layout* layout) const {
+    if (impl_.Match(layout)) {
+      if (matched_layout_) {
+        *matched_layout_ = layout;
+      }
+      return true;
+    }
+    return false;
+  }
+
+  // Modifies the pattern to match only if the layout equals the given proto.
+  // The layout must outlive the returned pattern.
+  constexpr LayoutPattern<LayoutType, LayoutPatternEqualImpl<Impl>> EqualTo(
+      const Layout* layout) const {
+    return LayoutPattern<LayoutType, LayoutPatternEqualImpl<Impl>>(
+        LayoutPatternEqualImpl<Impl>(impl_, layout), matched_layout_);
+  }
+
+  // Modifies the pattern to match only if the layout has a dense format.
+  constexpr LayoutPattern<LayoutType, LayoutPatternFormatImpl<Impl>>
+  WithDenseFormat() const {
+    return LayoutPattern<LayoutType, LayoutPatternFormatImpl<Impl>>(
+        LayoutPatternFormatImpl<Impl>(impl_, DENSE), matched_layout_);
+  }
+
+  // Modifies the pattern to match only if the layout has a sparse format.
+  constexpr LayoutPattern<LayoutType, LayoutPatternFormatImpl<Impl>>
+  WithSparseFormat() const {
+    return LayoutPattern<LayoutType, LayoutPatternFormatImpl<Impl>>(
+        LayoutPatternFormatImpl<Impl>(impl_, SPARSE), matched_layout_);
+  }
+
+ private:
+  Impl impl_;
+  LayoutType** matched_layout_;
+};
+
+}  // namespace detail
+
+// Creates a layout pattern that will capture the matched layout in the
+// argument.
+inline constexpr detail::LayoutPattern<const ::xla::Layout,
+                                       detail::LayoutPatternBaseImpl>
+Layout(const ::xla::Layout** matched_layout = nullptr) {
+  return detail::LayoutPattern<const ::xla::Layout,
+                               detail::LayoutPatternBaseImpl>(
+      detail::LayoutPatternBaseImpl(), matched_layout);
+}
+
+// Creates a layout pattern that will capture the matched layout in the
+// argument.
+inline constexpr detail::LayoutPattern<::xla::Layout,
+                                       detail::LayoutPatternBaseImpl>
+Layout(::xla::Layout** matched_layout) {
+  return detail::LayoutPattern<::xla::Layout, detail::LayoutPatternBaseImpl>(
+      detail::LayoutPatternBaseImpl(), matched_layout);
+}
+
+namespace detail {
+
+template <typename ShapeType, typename Impl>
+class ShapePattern;
+
+// The base ShapePattern implementation. Matches only if the shape is not
+// nullptr.
+class ShapePatternBaseImpl {
+ public:
+  bool Match(const ::xla::Shape* shape) const { return shape != nullptr; }
+};
+
+// A ShapePattern implementation that matches only if the shape equals a Shape
+// proto.
+template <typename Previous>
+class ShapePatternEqualImpl {
+ public:
+  explicit constexpr ShapePatternEqualImpl(const Previous& previous,
+                                           const ::xla::Shape* shape)
+      : previous_(previous), shape_(shape) {}
+
+  bool Match(const ::xla::Shape* shape) const {
+    return previous_.Match(shape) && ShapeUtil::Equal(*shape_, *shape);
+  }
+
+ private:
+  Previous previous_;
+  const ::xla::Shape* shape_;
+};
+
+// A ShapePattern implementation that matches only if the shape is compatible to
+// a Shape proto.
+template <typename Previous>
+class ShapePatternCompatibleImpl {
+ public:
+  explicit constexpr ShapePatternCompatibleImpl(const Previous& previous,
+                                                const ::xla::Shape* shape)
+      : previous_(previous), shape_(shape) {}
+
+  bool Match(const ::xla::Shape* shape) const {
+    return previous_.Match(shape) && ShapeUtil::Compatible(*shape_, *shape);
+  }
+
+ private:
+  Previous previous_;
+  const ::xla::Shape* shape_;
+};
+
+// A ShapePattern implementation that matches only if the shape has a given
+// element type.
+template <typename Previous>
+class ShapePatternElementTypeImpl {
+ public:
+  explicit constexpr ShapePatternElementTypeImpl(const Previous& previous,
+                                                 PrimitiveType element_type)
+      : previous_(previous), element_type_(element_type) {}
+
+  bool Match(const ::xla::Shape* shape) const {
+    return previous_.Match(shape) && shape->element_type() == element_type_;
+  }
+
+ private:
+  Previous previous_;
+  PrimitiveType element_type_;
+};
+
+// A ShapePattern implementation that matches only if the shape is scalar.
+template <typename Previous>
+class ShapePatternIsScalarImpl {
+ public:
+  explicit constexpr ShapePatternIsScalarImpl(const Previous& previous)
+      : previous_(previous) {}
+
+  bool Match(const ::xla::Shape* shape) const {
+    return previous_.Match(shape) && ShapeUtil::IsScalar(*shape);
+  }
+
+ private:
+  Previous previous_;
+};
+
+// A ShapePattern implementation that matches only if the shape is an array
+template <typename Previous>
+class ShapePatternIsArrayImpl {
+ public:
+  explicit constexpr ShapePatternIsArrayImpl(const Previous& previous)
+      : previous_(previous) {}
+
+  bool Match(const ::xla::Shape* shape) const {
+    return previous_.Match(shape) && ShapeUtil::IsArray(*shape);
+  }
+
+ private:
+  Previous previous_;
+};
+
+// A ShapePattern implementation that matches only if the shape is a tuple.
+template <typename Previous>
+class ShapePatternIsTupleImpl {
+ public:
+  explicit constexpr ShapePatternIsTupleImpl(const Previous& previous)
+      : previous_(previous) {}
+
+  bool Match(const ::xla::Shape* shape) const {
+    return previous_.Match(shape) && ShapeUtil::IsTuple(*shape);
+  }
+
+ private:
+  Previous previous_;
+};
+
+// A ShapePattern implementation that matches only if the shape has a given
+// rank.
+template <typename Previous>
+class ShapePatternRankImpl {
+ public:
+  explicit constexpr ShapePatternRankImpl(const Previous& previous, int64 rank)
+      : previous_(previous), rank_(rank) {}
+
+  bool Match(const ::xla::Shape* shape) const {
+    return previous_.Match(shape) && ShapeUtil::Rank(*shape) == rank_;
+  }
+
+ private:
+  Previous previous_;
+  int64 rank_;
+};
+
+// A ShapePattern implementation that matches only if the shape has a layout
+// that matches a given pattern.
+template <typename Previous, typename LayoutType, typename LayoutImpl>
+class ShapePatternLayoutImpl {
+ public:
+  explicit constexpr ShapePatternLayoutImpl(
+      const Previous& previous,
+      const LayoutPattern<LayoutType, LayoutImpl>& layout)
+      : previous_(previous), layout_(layout) {}
+
+  bool Match(const ::xla::Shape* shape) const {
+    return previous_.Match(shape) && LayoutUtil::HasLayout(*shape) &&
+           layout_.Match(&shape->layout());
+  }
+
+  bool Match(Shape* shape) const {
+    return previous_.Match(shape) && LayoutUtil::HasLayout(*shape) &&
+           layout_.Match(shape->mutable_layout());
+  }
+
+ private:
+  Previous previous_;
+  LayoutPattern<LayoutType, LayoutImpl> layout_;
+};
+
+// A ShapePattern implementation that matches only if the shape has a subshape
+// that matches a given pattern.
+template <typename Previous, typename SubshapeType, typename SubshapeImpl>
+class ShapePatternSubshapeImpl {
+ public:
+  explicit ShapePatternSubshapeImpl(
+      const Previous& previous, ShapeIndexView index,
+      const ShapePattern<SubshapeType, SubshapeImpl>& subshape)
+      : previous_(previous), index_(index), subshape_(subshape) {}
+
+  bool Match(const ::xla::Shape* shape) const {
+    return previous_.Match(shape) && ShapeUtil::IndexIsValid(*shape, index_) &&
+           subshape_.Match(&ShapeUtil::GetSubshape(*shape, index_));
+  }
+
+  bool Match(::xla::Shape* shape) const {
+    return previous_.Match(shape) && ShapeUtil::IndexIsValid(*shape, index_) &&
+           subshape_.Match(ShapeUtil::GetMutableSubshape(shape, index_));
+  }
+
+ private:
+  Previous previous_;
+  ShapeIndexView index_;
+  ShapePattern<SubshapeType, SubshapeImpl> subshape_;
+};
+
+// A pattern that matches Shapes.
+template <typename ShapeType, typename Impl>
+class ShapePattern {
+ public:
+  explicit constexpr ShapePattern(const Impl& impl, ShapeType** matched_shape)
+      : impl_(impl), matched_shape_(matched_shape) {}
+
+  // Returns true and captures the shape iff it matches the pattern.
+  bool Match(const ::xla::Shape* shape) const {
+    if (impl_.Match(shape)) {
+      if (matched_shape_) {
+        *matched_shape_ = shape;
+      }
+      return true;
+    }
+    return false;
+  }
+
+  // Returns true and captures the shape iff it matches the pattern.
+  bool Match(::xla::Shape* shape) const {
+    if (impl_.Match(shape)) {
+      if (matched_shape_) {
+        *matched_shape_ = shape;
+      }
+      return true;
+    }
+    return false;
+  }
+
+  // Modifies the pattern to match only if the shape equals the given proto.
+  // The layout must outlive the returned pattern.
+  constexpr ShapePattern<ShapeType, ShapePatternEqualImpl<Impl>> EqualTo(
+      const ::xla::Shape* shape) const {
+    return ShapePattern<ShapeType, ShapePatternEqualImpl<Impl>>(
+        ShapePatternEqualImpl<Impl>(impl_, shape), matched_shape_);
+  }
+
+  // Modifies the pattern to match only if the shape is compatible to the given
+  // proto. The layout must outlive the returned pattern.
+  constexpr ShapePattern<ShapeType, ShapePatternCompatibleImpl<Impl>>
+  CompatibleTo(const ::xla::Shape* shape) const {
+    return ShapePattern<ShapeType, ShapePatternCompatibleImpl<Impl>>(
+        ShapePatternCompatibleImpl<Impl>(impl_, shape), matched_shape_);
+  }
+
+  // Modifies the pattern to match only if the shape has the given element type.
+  constexpr ShapePattern<ShapeType, ShapePatternElementTypeImpl<Impl>>
+  WithElementType(PrimitiveType element_type) const {
+    return ShapePattern<ShapeType, ShapePatternElementTypeImpl<Impl>>(
+        ShapePatternElementTypeImpl<Impl>(impl_, element_type), matched_shape_);
+  }
+
+  // Modifies the pattern to match only if the shape is scalar.
+  constexpr ShapePattern<ShapeType, ShapePatternIsScalarImpl<Impl>> IsScalar()
+      const {
+    return ShapePattern<ShapeType, ShapePatternIsScalarImpl<Impl>>(
+        ShapePatternIsScalarImpl<Impl>(impl_), matched_shape_);
+  }
+
+  // Modifies the pattern to match only if the shape is an array.
+  constexpr ShapePattern<ShapeType, ShapePatternIsArrayImpl<Impl>> IsArray()
+      const {
+    return ShapePattern<ShapeType, ShapePatternIsArrayImpl<Impl>>(
+        ShapePatternIsArrayImpl<Impl>(impl_), matched_shape_);
+  }
+
+  // Modifies the pattern to match only if the shape is a tuple.
+  constexpr ShapePattern<ShapeType, ShapePatternIsTupleImpl<Impl>> IsTuple()
+      const {
+    return ShapePattern<ShapeType, ShapePatternIsTupleImpl<Impl>>(
+        ShapePatternIsTupleImpl<Impl>(impl_), matched_shape_);
+  }
+
+  // Modifies the pattern to match only if the shape has the given rank.
+  constexpr ShapePattern<ShapeType, ShapePatternRankImpl<Impl>> WithRank(
+      int64 rank) const {
+    return ShapePattern<ShapeType, ShapePatternRankImpl<Impl>>(
+        ShapePatternRankImpl<Impl>(impl_, rank), matched_shape_);
+  }
+
+  // Modifies the pattern to match only if the shape has a layout that matches
+  // the given pattern.
+  template <typename LayoutType, typename LayoutImpl>
+  constexpr ShapePattern<ShapeType,
+                         ShapePatternLayoutImpl<Impl, LayoutType, LayoutImpl>>
+  WithLayout(const LayoutPattern<LayoutType, LayoutImpl>& layout) const {
+    return ShapePattern<ShapeType,
+                        ShapePatternLayoutImpl<Impl, LayoutType, LayoutImpl>>(
+        ShapePatternLayoutImpl<Impl, LayoutType, LayoutImpl>(impl_, layout),
+        matched_shape_);
+  }
+
+  constexpr ShapePattern<
+      ShapeType,
+      ShapePatternLayoutImpl<Impl, const ::xla::Layout,
+                             LayoutPatternEqualImpl<LayoutPatternBaseImpl>>>
+  WithLayoutEqualTo(const ::xla::Layout* layout) const {
+    return WithLayout(Layout().EqualTo(layout));
+  }
+
+  constexpr ShapePattern<
+      ShapeType,
+      ShapePatternLayoutImpl<Impl, const ::xla::Layout,
+                             LayoutPatternFormatImpl<LayoutPatternBaseImpl>>>
+  IsDenseArray(const ::xla::Layout* layout) const {
+    return WithLayout(Layout().WithDenseFormat());
+  }
+
+  constexpr ShapePattern<
+      ShapeType,
+      ShapePatternLayoutImpl<Impl, const ::xla::Layout,
+                             LayoutPatternFormatImpl<LayoutPatternBaseImpl>>>
+  IsSparseArray(const ::xla::Layout* layout) const {
+    return WithLayout(Layout().WithSparseFormat());
+  }
+
+  // Modifies the pattern to match only if the shape has a subshape that matches
+  // the given pattern.
+  template <typename SubshapeType, typename SubshapeImpl>
+  ShapePattern<ShapeType,
+               ShapePatternSubshapeImpl<Impl, SubshapeType, SubshapeImpl>>
+  WithSubshape(ShapeIndexView index,
+               const ShapePattern<SubshapeType, SubshapeImpl>& subshape) const {
+    return ShapePattern<
+        ShapeType, ShapePatternSubshapeImpl<Impl, SubshapeType, SubshapeImpl>>(
+        ShapePatternSubshapeImpl<Impl, SubshapeType, SubshapeImpl>(impl_, index,
+                                                                   subshape),
+        matched_shape_);
+  }
+
+  ShapePattern<ShapeType, ShapePatternSubshapeImpl<
+                              Impl, const ::xla::Shape,
+                              ShapePatternEqualImpl<ShapePatternBaseImpl>>>
+  WithSubshapeEqualTo(ShapeIndexView index, const ::xla::Shape* shape) const {
+    return WithSubshape(index,
+                        ShapePattern<const ::xla::Shape, ShapePatternBaseImpl>(
+                            ShapePatternBaseImpl(), nullptr)
+                            .EqualTo(shape));
+  }
+
+  ShapePattern<ShapeType, ShapePatternSubshapeImpl<
+                              Impl, const ::xla::Shape,
+                              ShapePatternCompatibleImpl<ShapePatternBaseImpl>>>
+  WithSubshapeCompatibleTo(ShapeIndexView index,
+                           const ::xla::Shape* shape) const {
+    return WithSubshape(index,
+                        ShapePattern<const ::xla::Shape, ShapePatternBaseImpl>(
+                            ShapePatternBaseImpl(), nullptr)
+                            .CompatibleTo(shape));
+  }
+
+ private:
+  Impl impl_;
+  ShapeType** matched_shape_;
+};
+
+}  // namespace detail
+
+// Creates a shape pattern that will capture the matched layout in the argument.
+inline constexpr detail::ShapePattern<const ::xla::Shape,
+                                      detail::ShapePatternBaseImpl>
+Shape(const ::xla::Shape** matched_shape = nullptr) {
+  return detail::ShapePattern<const ::xla::Shape, detail::ShapePatternBaseImpl>(
+      detail::ShapePatternBaseImpl(), matched_shape);
+}
+
+// Creates a shape pattern that will capture the matched layout in the argument.
+inline constexpr detail::ShapePattern<::xla::Shape,
+                                      detail::ShapePatternBaseImpl>
+Shape(::xla::Shape** matched_shape) {
+  return detail::ShapePattern<::xla::Shape, detail::ShapePatternBaseImpl>(
+      detail::ShapePatternBaseImpl(), matched_shape);
+}
+
+namespace detail {
+
+template <typename HloInstructionType, typename Impl>
+class HloInstructionPattern;
+
+// The base HloInstructionPattern implementation. Matches only if the
+// instruction is not nullptr.
+class HloInstructionPatternBaseImpl {
+ public:
+  bool Match(const ::xla::HloInstruction* inst) const {
+    return inst != nullptr;
+  }
+};
+
+// An HloInstructionPattern implementation that matches only if the instruction
+// has a given name.
+template <typename Previous>
+class HloInstructionPatternNameImpl {
+ public:
+  explicit HloInstructionPatternNameImpl(const Previous& previous,
+                                         tensorflow::StringPiece name)
+      : previous_(previous), name_(name) {}
+
+  bool Match(const ::xla::HloInstruction* inst) const {
+    return previous_.Match(inst) && inst->name() == name_;
+  }
+
+ private:
+  Previous previous_;
+  tensorflow::StringPiece name_;
+};
+
+// An HloInstructionPattern implementation that matches only if the instruction
+// has a given opcode.
+template <typename Previous>
+class HloInstructionPatternOpcodeImpl {
+ public:
+  explicit constexpr HloInstructionPatternOpcodeImpl(const Previous& previous,
+                                                     HloOpcode opcode,
+                                                     bool invert)
+      : previous_(previous), opcode_(opcode), invert_(invert) {}
+
+  bool Match(const ::xla::HloInstruction* inst) const {
+    return previous_.Match(inst) && (invert_ ^ (inst->opcode() == opcode_));
+  }
+
+ private:
+  Previous previous_;
+  HloOpcode opcode_;
+  bool invert_;
+};
+
+// An HloInstructionPattern implementation that matches only if the instruction
+// has a shape that matches a given pattern.
+template <typename Previous, typename ShapeType, typename ShapeImpl>
+class HloInstructionPatternShapeImpl {
+ public:
+  explicit constexpr HloInstructionPatternShapeImpl(
+      const Previous& previous, const ShapePattern<ShapeType, ShapeImpl>& shape)
+      : previous_(previous), shape_(shape) {}
+
+  bool Match(const ::xla::HloInstruction* inst) const {
+    return previous_.Match(inst) && shape_.Match(&inst->shape());
+  }
+
+  bool Match(::xla::HloInstruction* inst) const {
+    return previous_.Match(inst) && shape_.Match(inst->mutable_shape());
+  }
+
+ private:
+  Previous previous_;
+  ShapePattern<ShapeType, ShapeImpl> shape_;
+};
+
+// An HloInstructionPattern implementation that matches only if the instruction
+// has an operand that matches a given pattern.
+template <typename Previous, typename OperandType, typename OperandImpl>
+class HloInstructionPatternOperandImpl {
+ public:
+  explicit constexpr HloInstructionPatternOperandImpl(
+      const Previous& previous, int64 operand_index,
+      const HloInstructionPattern<OperandType, OperandImpl>& operand)
+      : previous_(previous), operand_index_(operand_index), operand_(operand) {}
+
+  bool Match(const ::xla::HloInstruction* inst) const {
+    return previous_.Match(inst) && operand_index_ < inst->operand_count() &&
+           operand_.Match(inst->operand(operand_index_));
+  }
+
+  bool Match(::xla::HloInstruction* inst) const {
+    return previous_.Match(inst) && operand_index_ < inst->operand_count() &&
+           operand_.Match(inst->mutable_operand(operand_index_));
+  }
+
+ private:
+  Previous previous_;
+  int64 operand_index_;
+  HloInstructionPattern<OperandType, OperandImpl> operand_;
+};
+
+// A pattern that matches HloInstructions.
+template <typename HloInstructionType, typename Impl>
+class HloInstructionPattern {
+ public:
+  explicit constexpr HloInstructionPattern(const Impl& impl,
+                                           HloInstructionType** matched_inst)
+      : impl_(impl), matched_inst_(matched_inst) {}
+
+  // Returns true and captures the instruction iff it matches the pattern.
+  bool Match(const ::xla::HloInstruction* inst) const {
+    if (impl_.Match(inst)) {
+      if (matched_inst_) {
+        *matched_inst_ = inst;
+      }
+      return true;
+    }
+    return false;
+  }
+
+  // Returns true and captures the instruction iff it matches the pattern.
+  bool Match(::xla::HloInstruction* inst) const {
+    if (impl_.Match(inst)) {
+      if (matched_inst_) {
+        *matched_inst_ = inst;
+      }
+      return true;
+    }
+    return false;
+  }
+
+  // Modifies the pattern to match only if the instruction has the given name.
+  HloInstructionPattern<HloInstructionType, HloInstructionPatternNameImpl<Impl>>
+  WithName(tensorflow::StringPiece name) const {
+    return HloInstructionPattern<HloInstructionType,
+                                 HloInstructionPatternNameImpl<Impl>>(
+        HloInstructionPatternNameImpl<Impl>(impl_, name), matched_inst_);
+  }
+
+  // Modifies the pattern to match only if the instruction has the given opcode.
+  constexpr HloInstructionPattern<HloInstructionType,
+                                  HloInstructionPatternOpcodeImpl<Impl>>
+  WithOpcode(HloOpcode opcode) const {
+    return HloInstructionPattern<HloInstructionType,
+                                 HloInstructionPatternOpcodeImpl<Impl>>(
+        HloInstructionPatternOpcodeImpl<Impl>(impl_, opcode, false),
+        matched_inst_);
+  }
+
+  // Modifies the pattern to match only if the instruction does not have the
+  // given opcode.
+  constexpr HloInstructionPattern<HloInstructionType,
+                                  HloInstructionPatternOpcodeImpl<Impl>>
+  WithoutOpcode(HloOpcode opcode) const {
+    return HloInstructionPattern<HloInstructionType,
+                                 HloInstructionPatternOpcodeImpl<Impl>>(
+        HloInstructionPatternOpcodeImpl<Impl>(impl_, opcode, true),
+        matched_inst_);
+  }
+
+  // Modifies the pattern to match only if the instruction is a constant.
+  constexpr HloInstructionPattern<HloInstructionType,
+                                  HloInstructionPatternOpcodeImpl<Impl>>
+  IsConstant() const {
+    return WithOpcode(HloOpcode::kConstant);
+  }
+
+  // Modifies the pattern to match only if the instruction is not a constant.
+  constexpr HloInstructionPattern<HloInstructionType,
+                                  HloInstructionPatternOpcodeImpl<Impl>>
+  IsNonConstant() const {
+    return WithoutOpcode(HloOpcode::kConstant);
+  }
+
+  // Modifies the pattern to match only if the instruction has a shape that
+  // matches the given pattern.
+  template <typename ShapeType, typename ShapeImpl>
+  constexpr HloInstructionPattern<
+      HloInstructionType,
+      HloInstructionPatternShapeImpl<Impl, ShapeType, ShapeImpl>>
+  WithShape(const ShapePattern<ShapeType, ShapeImpl>& shape) const {
+    return HloInstructionPattern<
+        HloInstructionType,
+        HloInstructionPatternShapeImpl<Impl, ShapeType, ShapeImpl>>(
+        HloInstructionPatternShapeImpl<Impl, ShapeType, ShapeImpl>(impl_,
+                                                                   shape),
+        matched_inst_);
+  }
+
+  // Modifies the pattern to match only if the instruction has an operand that
+  // matches the given pattern.
+  template <typename OperandType, typename OperandImpl>
+  constexpr HloInstructionPattern<
+      HloInstructionType,
+      HloInstructionPatternOperandImpl<Impl, OperandType, OperandImpl>>
+  WithOperand(
+      int64 operand_index,
+      const HloInstructionPattern<OperandType, OperandImpl>& operand) const {
+    return HloInstructionPattern<
+        HloInstructionType,
+        HloInstructionPatternOperandImpl<Impl, OperandType, OperandImpl>>(
+        HloInstructionPatternOperandImpl<Impl, OperandType, OperandImpl>(
+            impl_, operand_index, operand),
+        matched_inst_);
+  }
+
+ private:
+  Impl impl_;
+  HloInstructionType** matched_inst_;
+};
+
+}  // namespace detail
+
+// Creates an instruction pattern that will capture the matched instruction in
+// the argument.
+inline constexpr detail::HloInstructionPattern<
+    const ::xla::HloInstruction, detail::HloInstructionPatternBaseImpl>
+Op(const ::xla::HloInstruction** matched_inst = nullptr) {
+  return detail::HloInstructionPattern<const ::xla::HloInstruction,
+                                       detail::HloInstructionPatternBaseImpl>(
+      detail::HloInstructionPatternBaseImpl(), matched_inst);
+}
+
+// Creates an instruction pattern that will capture the matched instruction in
+// the argument.
+inline constexpr detail::HloInstructionPattern<
+    ::xla::HloInstruction, detail::HloInstructionPatternBaseImpl>
+Op(::xla::HloInstruction** matched_inst) {
+  return detail::HloInstructionPattern<::xla::HloInstruction,
+                                       detail::HloInstructionPatternBaseImpl>(
+      detail::HloInstructionPatternBaseImpl(), matched_inst);
+}
+
+// Helpers for nullary instructions.
+#define XLA_NULLOP_PATTERN(NAME)                                      \
+  inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \
+    return Op().WithOpcode(HloOpcode::k##NAME);                       \
+  }                                                                   \
+                                                                      \
+  template <typename HloInstructionType>                              \
+  inline auto NAME(HloInstructionType** matched_inst)                 \
+      ->decltype(Op(matched_inst).WithOpcode(HloOpcode::k##NAME)) {   \
+    return Op(matched_inst).WithOpcode(HloOpcode::k##NAME);           \
+  }
+XLA_NULLOP_PATTERN(Constant)
+XLA_NULLOP_PATTERN(Infeed)
+XLA_NULLOP_PATTERN(Parameter)
+XLA_NULLOP_PATTERN(Recv)
+#undef XLA_NULLOP_PATTERN
+
+// Helpers for unary instructions.
+#define XLA_UNOP_PATTERN(NAME)                                        \
+  inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \
+    return Op().WithOpcode(HloOpcode::k##NAME);                       \
+  }                                                                   \
+                                                                      \
+  template <typename Arg>                                             \
+  inline auto NAME(Arg&& arg)->decltype(                              \
+      Op().WithOpcode(HloOpcode::k##NAME)                             \
+          .WithOperand(0, std::forward<Arg>(arg))) {                  \
+    return Op()                                                       \
+        .WithOpcode(HloOpcode::k##NAME)                               \
+        .WithOperand(0, std::forward<Arg>(arg));                      \
+  }                                                                   \
+                                                                      \
+  template <typename HloInstructionType, typename Arg>                \
+  inline auto NAME(HloInstructionType** matched_inst, Arg&& arg)      \
+      ->decltype(Op(matched_inst)                                     \
+                     .WithOpcode(HloOpcode::k##NAME)                  \
+                     .WithOperand(0, std::forward<Arg>(arg))) {       \
+    return Op(matched_inst)                                           \
+        .WithOpcode(HloOpcode::k##NAME)                               \
+        .WithOperand(0, std::forward<Arg>(arg));                      \
+  }
+XLA_UNOP_PATTERN(Abs)
+XLA_UNOP_PATTERN(RoundNearestAfz)
+XLA_UNOP_PATTERN(Bitcast)
+XLA_UNOP_PATTERN(Broadcast)
+XLA_UNOP_PATTERN(BroadcastDimOne)
+XLA_UNOP_PATTERN(Ceil)
+XLA_UNOP_PATTERN(Copy)
+XLA_UNOP_PATTERN(Cos)
+XLA_UNOP_PATTERN(Exp)
+XLA_UNOP_PATTERN(Fft)
+XLA_UNOP_PATTERN(Floor)
+XLA_UNOP_PATTERN(Imag)
+XLA_UNOP_PATTERN(IsFinite)
+XLA_UNOP_PATTERN(Log)
+XLA_UNOP_PATTERN(Not)
+XLA_UNOP_PATTERN(Negate)
+XLA_UNOP_PATTERN(Outfeed)
+XLA_UNOP_PATTERN(Real)
+XLA_UNOP_PATTERN(Reduce)
+XLA_UNOP_PATTERN(ReducePrecision)
+XLA_UNOP_PATTERN(Reshape)
+XLA_UNOP_PATTERN(Reverse)
+XLA_UNOP_PATTERN(Send)
+XLA_UNOP_PATTERN(Sign)
+XLA_UNOP_PATTERN(Sin)
+XLA_UNOP_PATTERN(Sort)
+XLA_UNOP_PATTERN(Tanh)
+XLA_UNOP_PATTERN(Transpose)
+#undef XLA_UNOP_PATTERN
+
+// Helpers for binary instructions.
+#define XLA_BINOP_PATTERN(NAME)                                             \
+  inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) {       \
+    return Op().WithOpcode(HloOpcode::k##NAME);                             \
+  }                                                                         \
+                                                                            \
+  template <typename Lhs, typename Rhs>                                     \
+  inline auto NAME(Lhs&& lhs, Rhs&& rhs)                                    \
+      ->decltype(Op().WithOpcode(HloOpcode::k##NAME)                        \
+                     .WithOperand(0, std::forward<Lhs>(lhs))                \
+                     .WithOperand(1, std::forward<Rhs>(rhs))) {             \
+    return Op()                                                             \
+        .WithOpcode(HloOpcode::k##NAME)                                     \
+        .WithOperand(0, std::forward<Lhs>(lhs))                             \
+        .WithOperand(1, std::forward<Rhs>(rhs));                            \
+  }                                                                         \
+                                                                            \
+  template <typename HloInstructionType, typename Lhs, typename Rhs>        \
+  inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) \
+      ->decltype(Op(matched_inst)                                           \
+                     .WithOpcode(HloOpcode::k##NAME)                        \
+                     .WithOperand(0, std::forward<Lhs>(lhs))                \
+                     .WithOperand(1, std::forward<Rhs>(rhs))) {             \
+    return Op(matched_inst)                                                 \
+        .WithOpcode(HloOpcode::k##NAME)                                     \
+        .WithOperand(0, std::forward<Lhs>(lhs))                             \
+        .WithOperand(1, std::forward<Rhs>(rhs));                            \
+  }
+XLA_BINOP_PATTERN(Add)
+XLA_BINOP_PATTERN(Atan2)
+XLA_BINOP_PATTERN(Divide)
+XLA_BINOP_PATTERN(Complex)
+XLA_BINOP_PATTERN(Dot)
+XLA_BINOP_PATTERN(Eq)
+XLA_BINOP_PATTERN(Gather)
+XLA_BINOP_PATTERN(Ge)
+XLA_BINOP_PATTERN(Gt)
+XLA_BINOP_PATTERN(Le)
+XLA_BINOP_PATTERN(Lt)
+XLA_BINOP_PATTERN(Maximum)
+XLA_BINOP_PATTERN(Minimum)
+XLA_BINOP_PATTERN(Multiply)
+XLA_BINOP_PATTERN(Ne)
+XLA_BINOP_PATTERN(Power)
+XLA_BINOP_PATTERN(Remainder)
+XLA_BINOP_PATTERN(Subtract)
+XLA_BINOP_PATTERN(And)
+XLA_BINOP_PATTERN(Or)
+XLA_BINOP_PATTERN(ShiftLeft)
+XLA_BINOP_PATTERN(ShiftRightArithmetic)
+XLA_BINOP_PATTERN(ShiftRightLogical)
+#undef XLA_BINOP_PATTERN
+
+// Helpers for ternary instructions.
+#define XLA_TERNOP_PATTERN(NAME)                                       \
+  inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) {  \
+    return Op().WithOpcode(HloOpcode::k##NAME);                        \
+  }                                                                    \
+                                                                       \
+  template <typename Arg0, typename Arg1, typename Arg2>               \
+  inline auto NAME(Arg0&& arg0, Arg1&& arg1, Arg2&& arg2)              \
+      ->decltype(Op().WithOpcode(HloOpcode::k##NAME)                   \
+                     .WithOperand(0, std::forward<Arg0>(arg0))         \
+                     .WithOperand(1, std::forward<Arg1>(arg1))         \
+                     .WithOperand(2, std::forward<Arg2>(arg2))) {      \
+    return Op()                                                        \
+        .WithOpcode(HloOpcode::k##NAME)                                \
+        .WithOperand(0, std::forward<Arg0>(arg0))                      \
+        .WithOperand(1, std::forward<Arg1>(arg1))                      \
+        .WithOperand(2, std::forward<Arg2>(arg2));                     \
+  }                                                                    \
+                                                                       \
+  template <typename HloInstructionType, typename Arg0, typename Arg1, \
+            typename Arg2>                                             \
+  inline auto NAME(HloInstructionType** matched_inst, Arg0&& arg0,     \
+                   Arg1&& arg1, Arg2&& arg2)                           \
+      ->decltype(Op(matched_inst)                                      \
+                     .WithOpcode(HloOpcode::k##NAME)                   \
+                     .WithOperand(0, std::forward<Arg0>(arg0))         \
+                     .WithOperand(1, std::forward<Arg1>(arg1))         \
+                     .WithOperand(2, std::forward<Arg2>(arg2))) {      \
+    return Op(matched_inst)                                            \
+        .WithOpcode(HloOpcode::k##NAME)                                \
+        .WithOperand(0, std::forward<Arg0>(arg0))                      \
+        .WithOperand(1, std::forward<Arg1>(arg1))                      \
+        .WithOperand(2, std::forward<Arg2>(arg2));                     \
+  }
+XLA_TERNOP_PATTERN(Clamp);
+XLA_TERNOP_PATTERN(Select);
+#undef XLA_TERNOP_PATTERN
+
+// Helpers for matching non-constant instructions.
+inline auto NonConstant() -> decltype(Op().IsNonConstant()) {
+  return Op().IsNonConstant();
+}
+
+template <typename HloInstructionType>
+inline auto NonConstant(HloInstructionType** matched_inst)
+    -> decltype(Op(matched_inst).IsNonConstant()) {
+  return Op(matched_inst).IsNonConstant();
+}
+
+}  // namespace match
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_
diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc
new file mode 100644 (file)
index 0000000..5291b14
--- /dev/null
@@ -0,0 +1,144 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/pattern_matcher.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+TEST(PatternMatcherTest, AddOp) {
+  constexpr char kModuleStr[] = R"(HloModule two_plus_two_module
+    ENTRY %two_plus_two_computation () -> f32[] {
+      %two = f32[] constant(2)
+      ROOT %two_plus_two = f32[] add(f32[] %two, f32[] %two)
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, tools::Parse(kModuleStr));
+
+  const HloInstruction* matched_inst;
+  HloInstruction* matched_operand;
+  Shape* matched_shape;
+  Layout* matched_layout;
+
+  ASSERT_TRUE(Match(
+      hlo_module->entry_computation()->root_instruction(),
+      match::Op(&matched_inst)
+          .WithName("two_plus_two")
+          .WithOpcode(HloOpcode::kAdd)
+          .WithShape(
+              match::Shape(&matched_shape)
+                  .WithLayout(match::Layout(&matched_layout).WithDenseFormat()))
+          .WithOperand(
+              0,
+              match::Op(&matched_operand).WithOpcode(HloOpcode::kConstant))));
+  ASSERT_NE(matched_inst, nullptr);
+  EXPECT_EQ(matched_inst->name(), "two_plus_two");
+  EXPECT_EQ(matched_inst->opcode(), HloOpcode::kAdd);
+
+  EXPECT_TRUE(Match(hlo_module->entry_computation()->root_instruction(),
+                    match::Add(match::Constant(), match::Constant())));
+
+  EXPECT_FALSE(Match(hlo_module->entry_computation()->root_instruction(),
+                     match::Op().WithName("bad_name")));
+  matched_inst = nullptr;
+  EXPECT_FALSE(Match(hlo_module->entry_computation()->root_instruction(),
+                     match::Multiply(&matched_inst, match::Op(), match::Op())));
+}
+
+TEST(PatternMatcherTest, ScalarShape) {
+  auto scalar_shape = ShapeUtil::MakeShape(F32, {});
+  Shape* matched_shape;
+  EXPECT_TRUE(Match(&scalar_shape, match::Shape(&matched_shape).IsScalar()));
+  EXPECT_EQ(matched_shape, &scalar_shape);
+  EXPECT_TRUE(Match(&scalar_shape, match::Shape().IsArray()));
+  EXPECT_FALSE(Match(&scalar_shape, match::Shape().IsTuple()));
+  EXPECT_TRUE(Match(&scalar_shape, match::Shape().WithElementType(F32)));
+  EXPECT_TRUE(Match(&scalar_shape, match::Shape().WithRank(0)));
+  EXPECT_FALSE(Match(
+      &scalar_shape,
+      match::Shape().WithSubshape({0}, match::Shape()).WithElementType(F32)));
+}
+
+TEST(PatternMatcherTest, ArrayShape) {
+  auto array_shape = ShapeUtil::MakeShape(F32, {2, 3, 4});
+  Shape* matched_shape;
+  EXPECT_TRUE(Match(&array_shape, match::Shape(&matched_shape).IsArray()));
+  EXPECT_EQ(matched_shape, &array_shape);
+  EXPECT_FALSE(Match(&array_shape, match::Shape().IsScalar()));
+  EXPECT_FALSE(Match(&array_shape, match::Shape().IsTuple()));
+  EXPECT_TRUE(Match(&array_shape, match::Shape().WithElementType(F32)));
+  EXPECT_TRUE(Match(&array_shape, match::Shape().WithRank(3)));
+  EXPECT_FALSE(
+      Match(&array_shape, match::Shape().WithSubshape({0}, match::Shape())));
+  Layout* matched_layout;
+  EXPECT_FALSE(Match(&array_shape,
+                     match::Shape().WithLayout(
+                         match::Layout(&matched_layout).WithSparseFormat())));
+}
+
+TEST(PatternMatcherTest, TupleShape) {
+  auto tuple_shape = ShapeUtil::MakeTupleShape({
+      ShapeUtil::MakeShape(F32, {1, 2, 3}),
+      ShapeUtil::MakeShape(S32, {4, 5}),
+  });
+  EXPECT_TRUE(Match(&tuple_shape, match::Shape().IsTuple()));
+  EXPECT_FALSE(Match(&tuple_shape, match::Shape().IsArray()));
+  EXPECT_FALSE(Match(&tuple_shape, match::Shape().IsScalar()));
+
+  Shape* subshape;
+  ASSERT_TRUE(Match(
+      &tuple_shape,
+      match::Shape().WithSubshape(
+          {0}, match::Shape(&subshape).WithElementType(F32).WithRank(3))));
+  ASSERT_NE(subshape, nullptr);
+  EXPECT_TRUE(
+      ShapeUtil::Equal(*subshape, ShapeUtil::GetSubshape(tuple_shape, {0})));
+  EXPECT_TRUE(Match(&tuple_shape,
+                    match::Shape().WithSubshape(
+                        {0}, match::Shape().EqualTo(
+                                 &ShapeUtil::GetSubshape(tuple_shape, {0})))));
+  EXPECT_FALSE(Match(&tuple_shape,
+                     match::Shape().WithSubshape(
+                         {0}, match::Shape().EqualTo(
+                                  &ShapeUtil::GetSubshape(tuple_shape, {1})))));
+
+  ASSERT_TRUE(Match(
+      &tuple_shape,
+      match::Shape().WithSubshape(
+          {1}, match::Shape(&subshape).WithElementType(S32).WithRank(2))));
+  ASSERT_NE(subshape, nullptr);
+  EXPECT_TRUE(
+      ShapeUtil::Equal(*subshape, ShapeUtil::GetSubshape(tuple_shape, {1})));
+  EXPECT_TRUE(Match(&tuple_shape,
+                    match::Shape().WithSubshape(
+                        {1}, match::Shape().EqualTo(
+                                 &ShapeUtil::GetSubshape(tuple_shape, {1})))));
+  EXPECT_FALSE(Match(&tuple_shape,
+                     match::Shape().WithSubshape(
+                         {1}, match::Shape().EqualTo(
+                                  &ShapeUtil::GetSubshape(tuple_shape, {0})))));
+
+  EXPECT_FALSE(
+      Match(&tuple_shape, match::Shape().WithSubshape({2}, match::Shape())));
+  EXPECT_FALSE(
+      Match(&tuple_shape, match::Shape().WithSubshape({0, 0}, match::Shape())));
+}
+
+}  // namespace
+}  // namespace xla
index 6825d24..ac7e201 100644 (file)
@@ -824,6 +824,18 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
   return new_shape;
 }
 
+/* static */ bool ShapeUtil::IndexIsValid(const Shape& shape,
+                                          ShapeIndexView index) {
+  const Shape* subshape = &shape;
+  for (auto i : index) {
+    if (!IsTuple(*subshape) || i >= subshape->tuple_shapes_size()) {
+      return false;
+    }
+    subshape = &subshape->tuple_shapes(i);
+  }
+  return true;
+}
+
 /* static */ const Shape& ShapeUtil::GetSubshape(const Shape& shape,
                                                  ShapeIndexView index) {
   const Shape* return_shape = &shape;
index 6d228ef..63da915 100644 (file)
@@ -448,6 +448,9 @@ class ShapeUtil {
   static bool ShapeIs(const Shape& shape, PrimitiveType element_type,
                       std::initializer_list<int64> dimensions);
 
+  // Returns true if the given shape has a subshape at the given index.
+  static bool IndexIsValid(const Shape& shape, ShapeIndexView index);
+
   // GetSubshape and GetMutableSubshape return a particular nested Shape within
   // the given Shape argument.
   static const Shape& GetSubshape(const Shape& shape, ShapeIndexView index);