From a22344f82ddd1e877f0b9f82584b9bb1d6c8dc16 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 13 Apr 2018 15:32:11 -0700 Subject: [PATCH] [XLA] Pattern matcher for HLO, Shapes, Layouts PiperOrigin-RevId: 192834129 --- tensorflow/compiler/xla/service/BUILD | 23 + tensorflow/compiler/xla/service/pattern_matcher.h | 1014 ++++++++++++++++++++ .../compiler/xla/service/pattern_matcher_test.cc | 144 +++ tensorflow/compiler/xla/shape_util.cc | 12 + tensorflow/compiler/xla/shape_util.h | 3 + 5 files changed, 1196 insertions(+) create mode 100644 tensorflow/compiler/xla/service/pattern_matcher.h create mode 100644 tensorflow/compiler/xla/service/pattern_matcher_test.cc diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 65203fa..ddc0998 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -303,6 +303,29 @@ tf_cc_test( ) cc_library( + name = "pattern_matcher", + hdrs = ["pattern_matcher.h"], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "pattern_matcher_test", + srcs = ["pattern_matcher_test.cc"], + deps = [ + ":hlo", + ":pattern_matcher", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + "//tensorflow/core:test", + ], +) + +cc_library( name = "hlo_reachability", srcs = ["hlo_reachability.cc"], hdrs = ["hlo_reachability.h"], diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h new file mode 100644 index 0000000..5d49638 --- /dev/null +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -0,0 +1,1014 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace xla { + +// A pattern matcher for HloInstructions, Shapes, and Layouts. +// +// The Match function's first argument must be HloInstruction*, Shape*, or +// Layout*. The second argument is a pattern that will be matched against the +// first argument, as described below. +// +// Patterns are constructed using the match::Op, match::Shape, or match::Layout +// functions. By default, the returned patterns will match any HloInstruction, +// Shape, or Layout, respectively. However the match can be made more specific +// by using the pattern's modifier methods, for example: +// +// match::Op().WithOpcode(HloOpcode::kAdd).WithOperand( +// 0, match::Op().WithOpcode(HloOpcode::kConstant)) +// +// This pattern will match Add instructions whose first operand is a constant. +// +// Each pattern type has the following modifiers: +// +// Op(): +// - WithName: match operations with the given name +// - WithOpcode: match operations with the given opcode +// - WithShape: match operations whose shape matches the given pattern +// - WithOperand: match operations whose operand matches the given pattern +// +// Shape(): +// - EqualTo: matches shapes that are equal to the argument +// - CompatibleTo: matches shapes that are compatible to the argument +// - IsScalar/IsArray/IsTuple: matches scalar/array/tuple shapes +// - IsDenseArray/IsSparseArray: matches arrays with dense/sparse format +// - WithLayout: match shapes whose layout matches the given pattern +// - WithLayoutEqualTo: matches shapes whose layouts equal the argument +// - WithSubshape: matches tuple shapes whose subshape matches the given +// pattern +// - WithSubshapeEqualTo: matches shapes with a subshape equal the argument +// - WithElementType: matches array/scalar shapes with the given element +// type +// - WithRank: matches array/scalar types with the given rank +// +// Layout(): +// - EqualTo: matches layouts that are equal to the argument +// - WithDenseFormat/WithSparseFormat: matches layouts with dense/sparse +// format +// +// Op(), Shape(), and Layout() may be passed an argument of type +// HloInstruction**, Shape**, or Layout**, respectively, or const versions of +// these pointers. If the pattern is matched, the address of the matched value +// will be "captured" and stored at this location. +// +// For example: +// HloInstruction* foo = ...; +// HloInstruction* matched_operand; +// CHECK(Match(foo, +// match::Op().WithOperand(0, match::Op(&matched_operand)))); +// +// Helpers are provided for common nullary, unary, binary, and ternary +// instructions. These helpers can be called with no arguments, in which case +// they will match any instruction matching the opcode. They may also be called +// with matches for the operands and with an optional capture. (The capture must +// be the first argument.) Some examples of these helpers and their equivalents +// are provided below. +// +// Example nullary instruction: +// Recv() == Op().WithOpcode(HloOpcode::kRecv) +// Recv(&a) == Op(&a).WithOpcode(HloOpcode::kRecv) +// +// Example unary instruction: +// Abs() == Op().WithOpcode(HloOpcode::kAbs) +// Abs(Op(&a)) == Op().WithOpcode(HloOpcode::kAbs) +// .WithOperand(0, Op(&a))) +// Abs(&a, Op(&b)) == Op(&a).WithOpcode(HloOpcode::kAbs) +// .WithOperand(0, Op(&b)) +// +// Example binary instruction: +// Add() == Op().WithOpcode(HloOpcode::kAdd) +// Add(Op(&a), Op(&b)) == Op().WithOpcode(HloOpcode::kAdd) +// .WithOperand(0, Op(&a)) +// .WithOperand(1, Op(&b)) +// Add(&a, Op(&b), Op(&c)) == Op(&a).WithOpcode(HloOpcode::kAdd) +// .WithOperand(0, Op(&b)) +// .WithOperand(1, Op(&c)) +// +// Example ternary instruction: +// Clamp() == Op().WithOpcode(HloOpcode::kClamp) +// Clamp(Op(&a), Op(&b), Op(&c)) == Op().WithOpcode(HloOpcode::kClamp) +// .WithOperand(0, Op(&a)) +// .WithOperand(1, Op(&b)) +// .WithOperand(2, Op(&c)) +// Clamp(&a, Op(&b), Op(&c), Op(&d)) == Op(&a).WithOpcode(HloOpcode::kClamp) +// .WithOperand(0, Op(&b)) +// .WithOperand(1, Op(&c)) +// .WithOperand(2, Op(&d)) +// +template +bool Match(Value* value, const Pattern& pattern) { + return pattern.Match(value); +} + +namespace match { + +namespace detail { + +template +class LayoutPattern; + +// The base LayoutPattern implementation. Matches only if the layout is not +// nullptr. +class LayoutPatternBaseImpl { + public: + bool Match(const ::xla::Layout* layout) const { return layout != nullptr; } +}; + +// A LayoutPattern implementation that matches only if the layout equals a +// Layout proto. +template +class LayoutPatternEqualImpl { + public: + explicit constexpr LayoutPatternEqualImpl(const Previous& previous, + const ::xla::Layout* layout) + : previous_(previous), layout_(layout) {} + + bool Match(const ::xla::Layout* layout) const { + return previous_.Match(layout) && LayoutUtil::Equal(*layout_, *layout); + } + + private: + Previous previous_; + const ::xla::Layout* layout_; +}; + +// A LayoutPattern implementation that matches only if the layout has a given +// format. +template +class LayoutPatternFormatImpl { + public: + explicit constexpr LayoutPatternFormatImpl(const Previous& previous, + Format format) + : previous_(previous), format_(format) {} + + bool Match(const ::xla::Layout* layout) const { + return previous_.Match(layout) && layout->format() == format_; + } + + private: + Previous previous_; + Format format_; +}; + +// A pattern that matches Layouts. +template +class LayoutPattern { + public: + explicit constexpr LayoutPattern(const Impl& impl, + LayoutType** matched_layout) + : impl_(impl), matched_layout_(matched_layout) {} + + // Returns true and captures the layout iff it matches the pattern. + bool Match(const ::xla::Layout* layout) const { + if (impl_.Match(layout)) { + if (matched_layout_) { + *matched_layout_ = layout; + } + return true; + } + return false; + } + + // Returns true and captures the layout iff it matches the pattern. + bool Match(::xla::Layout* layout) const { + if (impl_.Match(layout)) { + if (matched_layout_) { + *matched_layout_ = layout; + } + return true; + } + return false; + } + + // Modifies the pattern to match only if the layout equals the given proto. + // The layout must outlive the returned pattern. + constexpr LayoutPattern> EqualTo( + const Layout* layout) const { + return LayoutPattern>( + LayoutPatternEqualImpl(impl_, layout), matched_layout_); + } + + // Modifies the pattern to match only if the layout has a dense format. + constexpr LayoutPattern> + WithDenseFormat() const { + return LayoutPattern>( + LayoutPatternFormatImpl(impl_, DENSE), matched_layout_); + } + + // Modifies the pattern to match only if the layout has a sparse format. + constexpr LayoutPattern> + WithSparseFormat() const { + return LayoutPattern>( + LayoutPatternFormatImpl(impl_, SPARSE), matched_layout_); + } + + private: + Impl impl_; + LayoutType** matched_layout_; +}; + +} // namespace detail + +// Creates a layout pattern that will capture the matched layout in the +// argument. +inline constexpr detail::LayoutPattern +Layout(const ::xla::Layout** matched_layout = nullptr) { + return detail::LayoutPattern( + detail::LayoutPatternBaseImpl(), matched_layout); +} + +// Creates a layout pattern that will capture the matched layout in the +// argument. +inline constexpr detail::LayoutPattern<::xla::Layout, + detail::LayoutPatternBaseImpl> +Layout(::xla::Layout** matched_layout) { + return detail::LayoutPattern<::xla::Layout, detail::LayoutPatternBaseImpl>( + detail::LayoutPatternBaseImpl(), matched_layout); +} + +namespace detail { + +template +class ShapePattern; + +// The base ShapePattern implementation. Matches only if the shape is not +// nullptr. +class ShapePatternBaseImpl { + public: + bool Match(const ::xla::Shape* shape) const { return shape != nullptr; } +}; + +// A ShapePattern implementation that matches only if the shape equals a Shape +// proto. +template +class ShapePatternEqualImpl { + public: + explicit constexpr ShapePatternEqualImpl(const Previous& previous, + const ::xla::Shape* shape) + : previous_(previous), shape_(shape) {} + + bool Match(const ::xla::Shape* shape) const { + return previous_.Match(shape) && ShapeUtil::Equal(*shape_, *shape); + } + + private: + Previous previous_; + const ::xla::Shape* shape_; +}; + +// A ShapePattern implementation that matches only if the shape is compatible to +// a Shape proto. +template +class ShapePatternCompatibleImpl { + public: + explicit constexpr ShapePatternCompatibleImpl(const Previous& previous, + const ::xla::Shape* shape) + : previous_(previous), shape_(shape) {} + + bool Match(const ::xla::Shape* shape) const { + return previous_.Match(shape) && ShapeUtil::Compatible(*shape_, *shape); + } + + private: + Previous previous_; + const ::xla::Shape* shape_; +}; + +// A ShapePattern implementation that matches only if the shape has a given +// element type. +template +class ShapePatternElementTypeImpl { + public: + explicit constexpr ShapePatternElementTypeImpl(const Previous& previous, + PrimitiveType element_type) + : previous_(previous), element_type_(element_type) {} + + bool Match(const ::xla::Shape* shape) const { + return previous_.Match(shape) && shape->element_type() == element_type_; + } + + private: + Previous previous_; + PrimitiveType element_type_; +}; + +// A ShapePattern implementation that matches only if the shape is scalar. +template +class ShapePatternIsScalarImpl { + public: + explicit constexpr ShapePatternIsScalarImpl(const Previous& previous) + : previous_(previous) {} + + bool Match(const ::xla::Shape* shape) const { + return previous_.Match(shape) && ShapeUtil::IsScalar(*shape); + } + + private: + Previous previous_; +}; + +// A ShapePattern implementation that matches only if the shape is an array +template +class ShapePatternIsArrayImpl { + public: + explicit constexpr ShapePatternIsArrayImpl(const Previous& previous) + : previous_(previous) {} + + bool Match(const ::xla::Shape* shape) const { + return previous_.Match(shape) && ShapeUtil::IsArray(*shape); + } + + private: + Previous previous_; +}; + +// A ShapePattern implementation that matches only if the shape is a tuple. +template +class ShapePatternIsTupleImpl { + public: + explicit constexpr ShapePatternIsTupleImpl(const Previous& previous) + : previous_(previous) {} + + bool Match(const ::xla::Shape* shape) const { + return previous_.Match(shape) && ShapeUtil::IsTuple(*shape); + } + + private: + Previous previous_; +}; + +// A ShapePattern implementation that matches only if the shape has a given +// rank. +template +class ShapePatternRankImpl { + public: + explicit constexpr ShapePatternRankImpl(const Previous& previous, int64 rank) + : previous_(previous), rank_(rank) {} + + bool Match(const ::xla::Shape* shape) const { + return previous_.Match(shape) && ShapeUtil::Rank(*shape) == rank_; + } + + private: + Previous previous_; + int64 rank_; +}; + +// A ShapePattern implementation that matches only if the shape has a layout +// that matches a given pattern. +template +class ShapePatternLayoutImpl { + public: + explicit constexpr ShapePatternLayoutImpl( + const Previous& previous, + const LayoutPattern& layout) + : previous_(previous), layout_(layout) {} + + bool Match(const ::xla::Shape* shape) const { + return previous_.Match(shape) && LayoutUtil::HasLayout(*shape) && + layout_.Match(&shape->layout()); + } + + bool Match(Shape* shape) const { + return previous_.Match(shape) && LayoutUtil::HasLayout(*shape) && + layout_.Match(shape->mutable_layout()); + } + + private: + Previous previous_; + LayoutPattern layout_; +}; + +// A ShapePattern implementation that matches only if the shape has a subshape +// that matches a given pattern. +template +class ShapePatternSubshapeImpl { + public: + explicit ShapePatternSubshapeImpl( + const Previous& previous, ShapeIndexView index, + const ShapePattern& subshape) + : previous_(previous), index_(index), subshape_(subshape) {} + + bool Match(const ::xla::Shape* shape) const { + return previous_.Match(shape) && ShapeUtil::IndexIsValid(*shape, index_) && + subshape_.Match(&ShapeUtil::GetSubshape(*shape, index_)); + } + + bool Match(::xla::Shape* shape) const { + return previous_.Match(shape) && ShapeUtil::IndexIsValid(*shape, index_) && + subshape_.Match(ShapeUtil::GetMutableSubshape(shape, index_)); + } + + private: + Previous previous_; + ShapeIndexView index_; + ShapePattern subshape_; +}; + +// A pattern that matches Shapes. +template +class ShapePattern { + public: + explicit constexpr ShapePattern(const Impl& impl, ShapeType** matched_shape) + : impl_(impl), matched_shape_(matched_shape) {} + + // Returns true and captures the shape iff it matches the pattern. + bool Match(const ::xla::Shape* shape) const { + if (impl_.Match(shape)) { + if (matched_shape_) { + *matched_shape_ = shape; + } + return true; + } + return false; + } + + // Returns true and captures the shape iff it matches the pattern. + bool Match(::xla::Shape* shape) const { + if (impl_.Match(shape)) { + if (matched_shape_) { + *matched_shape_ = shape; + } + return true; + } + return false; + } + + // Modifies the pattern to match only if the shape equals the given proto. + // The layout must outlive the returned pattern. + constexpr ShapePattern> EqualTo( + const ::xla::Shape* shape) const { + return ShapePattern>( + ShapePatternEqualImpl(impl_, shape), matched_shape_); + } + + // Modifies the pattern to match only if the shape is compatible to the given + // proto. The layout must outlive the returned pattern. + constexpr ShapePattern> + CompatibleTo(const ::xla::Shape* shape) const { + return ShapePattern>( + ShapePatternCompatibleImpl(impl_, shape), matched_shape_); + } + + // Modifies the pattern to match only if the shape has the given element type. + constexpr ShapePattern> + WithElementType(PrimitiveType element_type) const { + return ShapePattern>( + ShapePatternElementTypeImpl(impl_, element_type), matched_shape_); + } + + // Modifies the pattern to match only if the shape is scalar. + constexpr ShapePattern> IsScalar() + const { + return ShapePattern>( + ShapePatternIsScalarImpl(impl_), matched_shape_); + } + + // Modifies the pattern to match only if the shape is an array. + constexpr ShapePattern> IsArray() + const { + return ShapePattern>( + ShapePatternIsArrayImpl(impl_), matched_shape_); + } + + // Modifies the pattern to match only if the shape is a tuple. + constexpr ShapePattern> IsTuple() + const { + return ShapePattern>( + ShapePatternIsTupleImpl(impl_), matched_shape_); + } + + // Modifies the pattern to match only if the shape has the given rank. + constexpr ShapePattern> WithRank( + int64 rank) const { + return ShapePattern>( + ShapePatternRankImpl(impl_, rank), matched_shape_); + } + + // Modifies the pattern to match only if the shape has a layout that matches + // the given pattern. + template + constexpr ShapePattern> + WithLayout(const LayoutPattern& layout) const { + return ShapePattern>( + ShapePatternLayoutImpl(impl_, layout), + matched_shape_); + } + + constexpr ShapePattern< + ShapeType, + ShapePatternLayoutImpl>> + WithLayoutEqualTo(const ::xla::Layout* layout) const { + return WithLayout(Layout().EqualTo(layout)); + } + + constexpr ShapePattern< + ShapeType, + ShapePatternLayoutImpl>> + IsDenseArray(const ::xla::Layout* layout) const { + return WithLayout(Layout().WithDenseFormat()); + } + + constexpr ShapePattern< + ShapeType, + ShapePatternLayoutImpl>> + IsSparseArray(const ::xla::Layout* layout) const { + return WithLayout(Layout().WithSparseFormat()); + } + + // Modifies the pattern to match only if the shape has a subshape that matches + // the given pattern. + template + ShapePattern> + WithSubshape(ShapeIndexView index, + const ShapePattern& subshape) const { + return ShapePattern< + ShapeType, ShapePatternSubshapeImpl>( + ShapePatternSubshapeImpl(impl_, index, + subshape), + matched_shape_); + } + + ShapePattern>> + WithSubshapeEqualTo(ShapeIndexView index, const ::xla::Shape* shape) const { + return WithSubshape(index, + ShapePattern( + ShapePatternBaseImpl(), nullptr) + .EqualTo(shape)); + } + + ShapePattern>> + WithSubshapeCompatibleTo(ShapeIndexView index, + const ::xla::Shape* shape) const { + return WithSubshape(index, + ShapePattern( + ShapePatternBaseImpl(), nullptr) + .CompatibleTo(shape)); + } + + private: + Impl impl_; + ShapeType** matched_shape_; +}; + +} // namespace detail + +// Creates a shape pattern that will capture the matched layout in the argument. +inline constexpr detail::ShapePattern +Shape(const ::xla::Shape** matched_shape = nullptr) { + return detail::ShapePattern( + detail::ShapePatternBaseImpl(), matched_shape); +} + +// Creates a shape pattern that will capture the matched layout in the argument. +inline constexpr detail::ShapePattern<::xla::Shape, + detail::ShapePatternBaseImpl> +Shape(::xla::Shape** matched_shape) { + return detail::ShapePattern<::xla::Shape, detail::ShapePatternBaseImpl>( + detail::ShapePatternBaseImpl(), matched_shape); +} + +namespace detail { + +template +class HloInstructionPattern; + +// The base HloInstructionPattern implementation. Matches only if the +// instruction is not nullptr. +class HloInstructionPatternBaseImpl { + public: + bool Match(const ::xla::HloInstruction* inst) const { + return inst != nullptr; + } +}; + +// An HloInstructionPattern implementation that matches only if the instruction +// has a given name. +template +class HloInstructionPatternNameImpl { + public: + explicit HloInstructionPatternNameImpl(const Previous& previous, + tensorflow::StringPiece name) + : previous_(previous), name_(name) {} + + bool Match(const ::xla::HloInstruction* inst) const { + return previous_.Match(inst) && inst->name() == name_; + } + + private: + Previous previous_; + tensorflow::StringPiece name_; +}; + +// An HloInstructionPattern implementation that matches only if the instruction +// has a given opcode. +template +class HloInstructionPatternOpcodeImpl { + public: + explicit constexpr HloInstructionPatternOpcodeImpl(const Previous& previous, + HloOpcode opcode, + bool invert) + : previous_(previous), opcode_(opcode), invert_(invert) {} + + bool Match(const ::xla::HloInstruction* inst) const { + return previous_.Match(inst) && (invert_ ^ (inst->opcode() == opcode_)); + } + + private: + Previous previous_; + HloOpcode opcode_; + bool invert_; +}; + +// An HloInstructionPattern implementation that matches only if the instruction +// has a shape that matches a given pattern. +template +class HloInstructionPatternShapeImpl { + public: + explicit constexpr HloInstructionPatternShapeImpl( + const Previous& previous, const ShapePattern& shape) + : previous_(previous), shape_(shape) {} + + bool Match(const ::xla::HloInstruction* inst) const { + return previous_.Match(inst) && shape_.Match(&inst->shape()); + } + + bool Match(::xla::HloInstruction* inst) const { + return previous_.Match(inst) && shape_.Match(inst->mutable_shape()); + } + + private: + Previous previous_; + ShapePattern shape_; +}; + +// An HloInstructionPattern implementation that matches only if the instruction +// has an operand that matches a given pattern. +template +class HloInstructionPatternOperandImpl { + public: + explicit constexpr HloInstructionPatternOperandImpl( + const Previous& previous, int64 operand_index, + const HloInstructionPattern& operand) + : previous_(previous), operand_index_(operand_index), operand_(operand) {} + + bool Match(const ::xla::HloInstruction* inst) const { + return previous_.Match(inst) && operand_index_ < inst->operand_count() && + operand_.Match(inst->operand(operand_index_)); + } + + bool Match(::xla::HloInstruction* inst) const { + return previous_.Match(inst) && operand_index_ < inst->operand_count() && + operand_.Match(inst->mutable_operand(operand_index_)); + } + + private: + Previous previous_; + int64 operand_index_; + HloInstructionPattern operand_; +}; + +// A pattern that matches HloInstructions. +template +class HloInstructionPattern { + public: + explicit constexpr HloInstructionPattern(const Impl& impl, + HloInstructionType** matched_inst) + : impl_(impl), matched_inst_(matched_inst) {} + + // Returns true and captures the instruction iff it matches the pattern. + bool Match(const ::xla::HloInstruction* inst) const { + if (impl_.Match(inst)) { + if (matched_inst_) { + *matched_inst_ = inst; + } + return true; + } + return false; + } + + // Returns true and captures the instruction iff it matches the pattern. + bool Match(::xla::HloInstruction* inst) const { + if (impl_.Match(inst)) { + if (matched_inst_) { + *matched_inst_ = inst; + } + return true; + } + return false; + } + + // Modifies the pattern to match only if the instruction has the given name. + HloInstructionPattern> + WithName(tensorflow::StringPiece name) const { + return HloInstructionPattern>( + HloInstructionPatternNameImpl(impl_, name), matched_inst_); + } + + // Modifies the pattern to match only if the instruction has the given opcode. + constexpr HloInstructionPattern> + WithOpcode(HloOpcode opcode) const { + return HloInstructionPattern>( + HloInstructionPatternOpcodeImpl(impl_, opcode, false), + matched_inst_); + } + + // Modifies the pattern to match only if the instruction does not have the + // given opcode. + constexpr HloInstructionPattern> + WithoutOpcode(HloOpcode opcode) const { + return HloInstructionPattern>( + HloInstructionPatternOpcodeImpl(impl_, opcode, true), + matched_inst_); + } + + // Modifies the pattern to match only if the instruction is a constant. + constexpr HloInstructionPattern> + IsConstant() const { + return WithOpcode(HloOpcode::kConstant); + } + + // Modifies the pattern to match only if the instruction is not a constant. + constexpr HloInstructionPattern> + IsNonConstant() const { + return WithoutOpcode(HloOpcode::kConstant); + } + + // Modifies the pattern to match only if the instruction has a shape that + // matches the given pattern. + template + constexpr HloInstructionPattern< + HloInstructionType, + HloInstructionPatternShapeImpl> + WithShape(const ShapePattern& shape) const { + return HloInstructionPattern< + HloInstructionType, + HloInstructionPatternShapeImpl>( + HloInstructionPatternShapeImpl(impl_, + shape), + matched_inst_); + } + + // Modifies the pattern to match only if the instruction has an operand that + // matches the given pattern. + template + constexpr HloInstructionPattern< + HloInstructionType, + HloInstructionPatternOperandImpl> + WithOperand( + int64 operand_index, + const HloInstructionPattern& operand) const { + return HloInstructionPattern< + HloInstructionType, + HloInstructionPatternOperandImpl>( + HloInstructionPatternOperandImpl( + impl_, operand_index, operand), + matched_inst_); + } + + private: + Impl impl_; + HloInstructionType** matched_inst_; +}; + +} // namespace detail + +// Creates an instruction pattern that will capture the matched instruction in +// the argument. +inline constexpr detail::HloInstructionPattern< + const ::xla::HloInstruction, detail::HloInstructionPatternBaseImpl> +Op(const ::xla::HloInstruction** matched_inst = nullptr) { + return detail::HloInstructionPattern( + detail::HloInstructionPatternBaseImpl(), matched_inst); +} + +// Creates an instruction pattern that will capture the matched instruction in +// the argument. +inline constexpr detail::HloInstructionPattern< + ::xla::HloInstruction, detail::HloInstructionPatternBaseImpl> +Op(::xla::HloInstruction** matched_inst) { + return detail::HloInstructionPattern<::xla::HloInstruction, + detail::HloInstructionPatternBaseImpl>( + detail::HloInstructionPatternBaseImpl(), matched_inst); +} + +// Helpers for nullary instructions. +#define XLA_NULLOP_PATTERN(NAME) \ + inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \ + return Op().WithOpcode(HloOpcode::k##NAME); \ + } \ + \ + template \ + inline auto NAME(HloInstructionType** matched_inst) \ + ->decltype(Op(matched_inst).WithOpcode(HloOpcode::k##NAME)) { \ + return Op(matched_inst).WithOpcode(HloOpcode::k##NAME); \ + } +XLA_NULLOP_PATTERN(Constant) +XLA_NULLOP_PATTERN(Infeed) +XLA_NULLOP_PATTERN(Parameter) +XLA_NULLOP_PATTERN(Recv) +#undef XLA_NULLOP_PATTERN + +// Helpers for unary instructions. +#define XLA_UNOP_PATTERN(NAME) \ + inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \ + return Op().WithOpcode(HloOpcode::k##NAME); \ + } \ + \ + template \ + inline auto NAME(Arg&& arg)->decltype( \ + Op().WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(arg))) { \ + return Op() \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(arg)); \ + } \ + \ + template \ + inline auto NAME(HloInstructionType** matched_inst, Arg&& arg) \ + ->decltype(Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(arg))) { \ + return Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(arg)); \ + } +XLA_UNOP_PATTERN(Abs) +XLA_UNOP_PATTERN(RoundNearestAfz) +XLA_UNOP_PATTERN(Bitcast) +XLA_UNOP_PATTERN(Broadcast) +XLA_UNOP_PATTERN(BroadcastDimOne) +XLA_UNOP_PATTERN(Ceil) +XLA_UNOP_PATTERN(Copy) +XLA_UNOP_PATTERN(Cos) +XLA_UNOP_PATTERN(Exp) +XLA_UNOP_PATTERN(Fft) +XLA_UNOP_PATTERN(Floor) +XLA_UNOP_PATTERN(Imag) +XLA_UNOP_PATTERN(IsFinite) +XLA_UNOP_PATTERN(Log) +XLA_UNOP_PATTERN(Not) +XLA_UNOP_PATTERN(Negate) +XLA_UNOP_PATTERN(Outfeed) +XLA_UNOP_PATTERN(Real) +XLA_UNOP_PATTERN(Reduce) +XLA_UNOP_PATTERN(ReducePrecision) +XLA_UNOP_PATTERN(Reshape) +XLA_UNOP_PATTERN(Reverse) +XLA_UNOP_PATTERN(Send) +XLA_UNOP_PATTERN(Sign) +XLA_UNOP_PATTERN(Sin) +XLA_UNOP_PATTERN(Sort) +XLA_UNOP_PATTERN(Tanh) +XLA_UNOP_PATTERN(Transpose) +#undef XLA_UNOP_PATTERN + +// Helpers for binary instructions. +#define XLA_BINOP_PATTERN(NAME) \ + inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \ + return Op().WithOpcode(HloOpcode::k##NAME); \ + } \ + \ + template \ + inline auto NAME(Lhs&& lhs, Rhs&& rhs) \ + ->decltype(Op().WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(lhs)) \ + .WithOperand(1, std::forward(rhs))) { \ + return Op() \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(lhs)) \ + .WithOperand(1, std::forward(rhs)); \ + } \ + \ + template \ + inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) \ + ->decltype(Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(lhs)) \ + .WithOperand(1, std::forward(rhs))) { \ + return Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(lhs)) \ + .WithOperand(1, std::forward(rhs)); \ + } +XLA_BINOP_PATTERN(Add) +XLA_BINOP_PATTERN(Atan2) +XLA_BINOP_PATTERN(Divide) +XLA_BINOP_PATTERN(Complex) +XLA_BINOP_PATTERN(Dot) +XLA_BINOP_PATTERN(Eq) +XLA_BINOP_PATTERN(Gather) +XLA_BINOP_PATTERN(Ge) +XLA_BINOP_PATTERN(Gt) +XLA_BINOP_PATTERN(Le) +XLA_BINOP_PATTERN(Lt) +XLA_BINOP_PATTERN(Maximum) +XLA_BINOP_PATTERN(Minimum) +XLA_BINOP_PATTERN(Multiply) +XLA_BINOP_PATTERN(Ne) +XLA_BINOP_PATTERN(Power) +XLA_BINOP_PATTERN(Remainder) +XLA_BINOP_PATTERN(Subtract) +XLA_BINOP_PATTERN(And) +XLA_BINOP_PATTERN(Or) +XLA_BINOP_PATTERN(ShiftLeft) +XLA_BINOP_PATTERN(ShiftRightArithmetic) +XLA_BINOP_PATTERN(ShiftRightLogical) +#undef XLA_BINOP_PATTERN + +// Helpers for ternary instructions. +#define XLA_TERNOP_PATTERN(NAME) \ + inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \ + return Op().WithOpcode(HloOpcode::k##NAME); \ + } \ + \ + template \ + inline auto NAME(Arg0&& arg0, Arg1&& arg1, Arg2&& arg2) \ + ->decltype(Op().WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(arg0)) \ + .WithOperand(1, std::forward(arg1)) \ + .WithOperand(2, std::forward(arg2))) { \ + return Op() \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(arg0)) \ + .WithOperand(1, std::forward(arg1)) \ + .WithOperand(2, std::forward(arg2)); \ + } \ + \ + template \ + inline auto NAME(HloInstructionType** matched_inst, Arg0&& arg0, \ + Arg1&& arg1, Arg2&& arg2) \ + ->decltype(Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(arg0)) \ + .WithOperand(1, std::forward(arg1)) \ + .WithOperand(2, std::forward(arg2))) { \ + return Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(arg0)) \ + .WithOperand(1, std::forward(arg1)) \ + .WithOperand(2, std::forward(arg2)); \ + } +XLA_TERNOP_PATTERN(Clamp); +XLA_TERNOP_PATTERN(Select); +#undef XLA_TERNOP_PATTERN + +// Helpers for matching non-constant instructions. +inline auto NonConstant() -> decltype(Op().IsNonConstant()) { + return Op().IsNonConstant(); +} + +template +inline auto NonConstant(HloInstructionType** matched_inst) + -> decltype(Op(matched_inst).IsNonConstant()) { + return Op(matched_inst).IsNonConstant(); +} + +} // namespace match + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc new file mode 100644 index 0000000..5291b14 --- /dev/null +++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc @@ -0,0 +1,144 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +TEST(PatternMatcherTest, AddOp) { + constexpr char kModuleStr[] = R"(HloModule two_plus_two_module + ENTRY %two_plus_two_computation () -> f32[] { + %two = f32[] constant(2) + ROOT %two_plus_two = f32[] add(f32[] %two, f32[] %two) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, tools::Parse(kModuleStr)); + + const HloInstruction* matched_inst; + HloInstruction* matched_operand; + Shape* matched_shape; + Layout* matched_layout; + + ASSERT_TRUE(Match( + hlo_module->entry_computation()->root_instruction(), + match::Op(&matched_inst) + .WithName("two_plus_two") + .WithOpcode(HloOpcode::kAdd) + .WithShape( + match::Shape(&matched_shape) + .WithLayout(match::Layout(&matched_layout).WithDenseFormat())) + .WithOperand( + 0, + match::Op(&matched_operand).WithOpcode(HloOpcode::kConstant)))); + ASSERT_NE(matched_inst, nullptr); + EXPECT_EQ(matched_inst->name(), "two_plus_two"); + EXPECT_EQ(matched_inst->opcode(), HloOpcode::kAdd); + + EXPECT_TRUE(Match(hlo_module->entry_computation()->root_instruction(), + match::Add(match::Constant(), match::Constant()))); + + EXPECT_FALSE(Match(hlo_module->entry_computation()->root_instruction(), + match::Op().WithName("bad_name"))); + matched_inst = nullptr; + EXPECT_FALSE(Match(hlo_module->entry_computation()->root_instruction(), + match::Multiply(&matched_inst, match::Op(), match::Op()))); +} + +TEST(PatternMatcherTest, ScalarShape) { + auto scalar_shape = ShapeUtil::MakeShape(F32, {}); + Shape* matched_shape; + EXPECT_TRUE(Match(&scalar_shape, match::Shape(&matched_shape).IsScalar())); + EXPECT_EQ(matched_shape, &scalar_shape); + EXPECT_TRUE(Match(&scalar_shape, match::Shape().IsArray())); + EXPECT_FALSE(Match(&scalar_shape, match::Shape().IsTuple())); + EXPECT_TRUE(Match(&scalar_shape, match::Shape().WithElementType(F32))); + EXPECT_TRUE(Match(&scalar_shape, match::Shape().WithRank(0))); + EXPECT_FALSE(Match( + &scalar_shape, + match::Shape().WithSubshape({0}, match::Shape()).WithElementType(F32))); +} + +TEST(PatternMatcherTest, ArrayShape) { + auto array_shape = ShapeUtil::MakeShape(F32, {2, 3, 4}); + Shape* matched_shape; + EXPECT_TRUE(Match(&array_shape, match::Shape(&matched_shape).IsArray())); + EXPECT_EQ(matched_shape, &array_shape); + EXPECT_FALSE(Match(&array_shape, match::Shape().IsScalar())); + EXPECT_FALSE(Match(&array_shape, match::Shape().IsTuple())); + EXPECT_TRUE(Match(&array_shape, match::Shape().WithElementType(F32))); + EXPECT_TRUE(Match(&array_shape, match::Shape().WithRank(3))); + EXPECT_FALSE( + Match(&array_shape, match::Shape().WithSubshape({0}, match::Shape()))); + Layout* matched_layout; + EXPECT_FALSE(Match(&array_shape, + match::Shape().WithLayout( + match::Layout(&matched_layout).WithSparseFormat()))); +} + +TEST(PatternMatcherTest, TupleShape) { + auto tuple_shape = ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShape(F32, {1, 2, 3}), + ShapeUtil::MakeShape(S32, {4, 5}), + }); + EXPECT_TRUE(Match(&tuple_shape, match::Shape().IsTuple())); + EXPECT_FALSE(Match(&tuple_shape, match::Shape().IsArray())); + EXPECT_FALSE(Match(&tuple_shape, match::Shape().IsScalar())); + + Shape* subshape; + ASSERT_TRUE(Match( + &tuple_shape, + match::Shape().WithSubshape( + {0}, match::Shape(&subshape).WithElementType(F32).WithRank(3)))); + ASSERT_NE(subshape, nullptr); + EXPECT_TRUE( + ShapeUtil::Equal(*subshape, ShapeUtil::GetSubshape(tuple_shape, {0}))); + EXPECT_TRUE(Match(&tuple_shape, + match::Shape().WithSubshape( + {0}, match::Shape().EqualTo( + &ShapeUtil::GetSubshape(tuple_shape, {0}))))); + EXPECT_FALSE(Match(&tuple_shape, + match::Shape().WithSubshape( + {0}, match::Shape().EqualTo( + &ShapeUtil::GetSubshape(tuple_shape, {1}))))); + + ASSERT_TRUE(Match( + &tuple_shape, + match::Shape().WithSubshape( + {1}, match::Shape(&subshape).WithElementType(S32).WithRank(2)))); + ASSERT_NE(subshape, nullptr); + EXPECT_TRUE( + ShapeUtil::Equal(*subshape, ShapeUtil::GetSubshape(tuple_shape, {1}))); + EXPECT_TRUE(Match(&tuple_shape, + match::Shape().WithSubshape( + {1}, match::Shape().EqualTo( + &ShapeUtil::GetSubshape(tuple_shape, {1}))))); + EXPECT_FALSE(Match(&tuple_shape, + match::Shape().WithSubshape( + {1}, match::Shape().EqualTo( + &ShapeUtil::GetSubshape(tuple_shape, {0}))))); + + EXPECT_FALSE( + Match(&tuple_shape, match::Shape().WithSubshape({2}, match::Shape()))); + EXPECT_FALSE( + Match(&tuple_shape, match::Shape().WithSubshape({0, 0}, match::Shape()))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 6825d24..ac7e201 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -824,6 +824,18 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return new_shape; } +/* static */ bool ShapeUtil::IndexIsValid(const Shape& shape, + ShapeIndexView index) { + const Shape* subshape = &shape; + for (auto i : index) { + if (!IsTuple(*subshape) || i >= subshape->tuple_shapes_size()) { + return false; + } + subshape = &subshape->tuple_shapes(i); + } + return true; +} + /* static */ const Shape& ShapeUtil::GetSubshape(const Shape& shape, ShapeIndexView index) { const Shape* return_shape = &shape; diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 6d228ef..63da915 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -448,6 +448,9 @@ class ShapeUtil { static bool ShapeIs(const Shape& shape, PrimitiveType element_type, std::initializer_list dimensions); + // Returns true if the given shape has a subshape at the given index. + static bool IndexIsValid(const Shape& shape, ShapeIndexView index); + // GetSubshape and GetMutableSubshape return a particular nested Shape within // the given Shape argument. static const Shape& GetSubshape(const Shape& shape, ShapeIndexView index); -- 2.7.4