Implement unrolling of vector ops to finer-grained vector ops as a pattern.
authorNicolas Vasilache <ntv@google.com>
Wed, 20 Nov 2019 18:54:45 +0000 (10:54 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 20 Nov 2019 19:49:36 +0000 (11:49 -0800)
This CL uses the pattern rewrite infrastructure to implement a simple VectorOps -> VectorOps legalization strategy to unroll coarse-grained vector operations into finer grained ones.
The transformation is written using local pattern rewrites to allow composition with other rewrites. It proceeds by iteratively introducing fake cast ops and cleaning canonicalizing or lowering them away where appropriate.

This is an example of writing transformations as compositions of local pattern rewrites that should enable us to make them significantly more declarative.

PiperOrigin-RevId: 281555100

13 files changed:
mlir/include/mlir/Analysis/VectorAnalysis.h
mlir/include/mlir/Conversion/VectorConversions/VectorConversions.h
mlir/include/mlir/Dialect/VectorOps/CMakeLists.txt
mlir/include/mlir/Dialect/VectorOps/VectorTransformPatterns.td [new file with mode: 0644]
mlir/lib/Analysis/VectorAnalysis.cpp
mlir/lib/Conversion/VectorConversions/CMakeLists.txt
mlir/lib/Conversion/VectorConversions/VectorToVector.cpp [new file with mode: 0644]
mlir/lib/Transforms/MaterializeVectors.cpp
mlir/test/Conversion/VectorConversions/vector-to-vector.mlir [new file with mode: 0644]
mlir/test/lib/Transforms/CMakeLists.txt
mlir/test/lib/Transforms/TestVectorToVectorConversion.cpp [new file with mode: 0644]
mlir/test/lib/Transforms/TestVectorizationUtils.cpp
mlir/tools/mlir-opt/CMakeLists.txt

index 8b9992d..350bdfd 100644 (file)
@@ -46,14 +46,14 @@ class VectorType;
 ///   - shapeRatio({3, 4, 5, 8}, {2, 5, 2}) returns {3, 2, 1, 4}
 ///   - shapeRatio({3, 4, 4, 8}, {2, 5, 2}) returns None
 ///   - shapeRatio({1, 2, 10, 32}, {2, 5, 2}) returns {1, 1, 2, 16}
-llvm::Optional<llvm::SmallVector<unsigned, 4>>
+llvm::Optional<llvm::SmallVector<int64_t, 4>>
 shapeRatio(ArrayRef<int64_t> superShape, ArrayRef<int64_t> subShape);
 
 /// Computes and returns the multi-dimensional ratio of the shapes of
 /// `superVector` to `subVector`. If integral division is not possible, returns
 /// None.
 /// Assumes and enforces that the VectorTypes have the same elemental type.
-llvm::Optional<llvm::SmallVector<unsigned, 4>>
+llvm::Optional<llvm::SmallVector<int64_t, 4>>
 shapeRatio(VectorType superVectorType, VectorType subVectorType);
 
 /// Constructs a permutation map of invariant memref indices to vector
index 33234b6..56862ca 100644 (file)
 // See the License for the specific language governing permissions and
 // limitations under the License.
 // =============================================================================
-#ifndef MLIR_CONVERSION_VECTORTOLLVM_VECTORTOLLVM_H_
-#define MLIR_CONVERSION_VECTORTOLLVM_VECTORTOLLVM_H_
+#ifndef MLIR_CONVERSION_VECTORCONVERSIONS_VECTORCONVERSIONS_H_
+#define MLIR_CONVERSION_VECTORCONVERSIONS_VECTORCONVERSIONS_H_
+
+#include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
 class LLVMTypeConverter;
 class MLIRContext;
 class ModuleOp;
 template <typename T> class OpPassBase;
-class OwningRewritePatternList;
 
 /// Collect a set of patterns to convert from the Vector dialect to affine loops
 /// surrounding ops in different dialects (vector, std etc).
@@ -31,6 +32,13 @@ class OwningRewritePatternList;
 void populateVectorToAffineLoopsConversionPatterns(
     MLIRContext *context, OwningRewritePatternList &patterns);
 
+/// Collect a set of patterns to convert from the Vector dialect to itself.
+/// Should be merged with populateVectorToAffineLoopsConversionPatterns.
+void populateVectorToVectorConversionPatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns,
+    ArrayRef<int64_t> coarseVectorShape = {},
+    ArrayRef<int64_t> fineVectorShape = {});
+
 /// Collect a set of patterns to convert from the Vector dialect to LLVM.
 void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                             OwningRewritePatternList &patterns);
@@ -40,4 +48,4 @@ OpPassBase<ModuleOp> *createLowerVectorToLLVMPass();
 
 } // namespace mlir
 
-#endif // MLIR_CONVERSION_VECTORTOLLVM_VECTORTOLLVM_H_
+#endif // MLIR_CONVERSION_VECTORCONVERSIONS_VECTORCONVERSIONS_H_
index 6cc7e44..3849dd7 100644 (file)
@@ -2,3 +2,7 @@ set(LLVM_TARGET_DEFINITIONS VectorOps.td)
 mlir_tablegen(VectorOps.h.inc -gen-op-decls)
 mlir_tablegen(VectorOps.cpp.inc -gen-op-defs)
 add_public_tablegen_target(MLIRVectorOpsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS VectorTransformPatterns.td)
+mlir_tablegen(VectorTransformPatterns.h.inc -gen-rewriters)
+add_public_tablegen_target(MLIRVectorTransformPatternsIncGen)
diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorTransformPatterns.td b/mlir/include/mlir/Dialect/VectorOps/VectorTransformPatterns.td
new file mode 100644 (file)
index 0000000..fe0940c
--- /dev/null
@@ -0,0 +1,43 @@
+//===- VectorTransformPatterns.td - Vector-Vector patterns -*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This is the pattern definition file for declarative Vector transformations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef VECTOR_TRANSFORMS
+#define VECTOR_TRANSFORMS
+
+include "mlir/Dialect/StandardOps/Ops.td"
+include "mlir/Dialect/VectorOps/VectorOps.td"
+
+class HasShape<list<int> shape> :
+  CPred<"hasShape($0, {" #  StrJoinInt<shape>.result # "})">;
+
+class UnrollVectorOp<list<int> factors> : NativeCodeCall<
+  "unrollSingleResultOpMatchingType($_builder, $0->getDefiningOp(), " #
+    "{" # StrJoinInt<factors>.result # "})">;
+
+def : Pat<(AddFOp:$op_results $a, $b),
+          (UnrollVectorOp<[2, 2]> $op_results, $a, $b),
+          [(Constraint<HasShape<[4, 2]>> $a)]>;
+
+def : Pat<(AddFOp:$op_results $a, $b),
+          (UnrollVectorOp<[2, 2]> $op_results, $a, $b),
+          [(Constraint<HasShape<[4, 4]>> $a)]>;
+
+#endif // VECTOR_TRANSFORMS
index 2dab348..1c028b4 100644 (file)
@@ -39,15 +39,15 @@ using namespace mlir;
 
 using llvm::SetVector;
 
-Optional<SmallVector<unsigned, 4>>
-mlir::shapeRatio(ArrayRef<int64_t> superShape, ArrayRef<int64_t> subShape) {
+Optional<SmallVector<int64_t, 4>> mlir::shapeRatio(ArrayRef<int64_t> superShape,
+                                                   ArrayRef<int64_t> subShape) {
   if (superShape.size() < subShape.size()) {
-    return Optional<SmallVector<unsigned, 4>>();
+    return Optional<SmallVector<int64_t, 4>>();
   }
 
   // Starting from the end, compute the integer divisors.
   // Set the boolean `divides` if integral division is not possible.
-  std::vector<unsigned> result;
+  std::vector<int64_t> result;
   result.reserve(superShape.size());
   bool divides = true;
   auto divide = [&divides, &result](int superSize, int subSize) {
@@ -76,11 +76,11 @@ mlir::shapeRatio(ArrayRef<int64_t> superShape, ArrayRef<int64_t> subShape) {
          "super to sub shape ratio is not of the same size as the super rank");
 
   // Reverse again to get it back in the proper order and return.
-  return SmallVector<unsigned, 4>{result.rbegin(), result.rend()};
+  return SmallVector<int64_t, 4>{result.rbegin(), result.rend()};
 }
 
-Optional<SmallVector<unsigned, 4>> mlir::shapeRatio(VectorType superVectorType,
-                                                    VectorType subVectorType) {
+Optional<SmallVector<int64_t, 4>> mlir::shapeRatio(VectorType superVectorType,
+                                                   VectorType subVectorType) {
   assert(superVectorType.getElementType() == subVectorType.getElementType() &&
          "vector types must be of the same elemental type");
   return shapeRatio(superVectorType.getShape(), subVectorType.getShape());
index f76b413..c8d699e 100644 (file)
@@ -1,6 +1,7 @@
-add_llvm_library(MLIRVectorToLLVM
+add_llvm_library(MLIRVectorConversions
   VectorToLLVM.cpp
   VectorToLoops.cpp
+  VectorToVector.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorConversions
@@ -12,5 +13,6 @@ set(LIBS
   LLVMSupport
   )
 
-add_dependencies(MLIRVectorToLLVM ${LIBS})
-target_link_libraries(MLIRVectorToLLVM ${LIBS})
+add_dependencies(MLIRVectorConversions ${LIBS})
+add_dependencies(MLIRVectorConversions MLIRVectorTransformPatternsIncGen)
+target_link_libraries(MLIRVectorConversions ${LIBS})
diff --git a/mlir/lib/Conversion/VectorConversions/VectorToVector.cpp b/mlir/lib/Conversion/VectorConversions/VectorToVector.cpp
new file mode 100644 (file)
index 0000000..7cc8083
--- /dev/null
@@ -0,0 +1,397 @@
+//===- VectorToLoops.cpp - Conversion within the Vector dialect -----------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements target-independent rewrites as 1->N patterns.
+//
+//===----------------------------------------------------------------------===//
+
+#include <type_traits>
+
+#include "mlir/Analysis/VectorAnalysis.h"
+#include "mlir/Conversion/VectorConversions/VectorConversions.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/Dialect/VectorOps/VectorOps.h"
+#include "mlir/EDSC/Builders.h"
+#include "mlir/EDSC/Helpers.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Support/Functional.h"
+#include "mlir/Support/STLExtras.h"
+
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+#define DEBUG_TYPE "vector-to-vector"
+
+using namespace mlir;
+using llvm::dbgs;
+using mlir::functional::zipMap;
+
+/// Given a shape with sizes greater than 0 along all dimensions,
+/// returns the distance, in number of elements, between a slice in a dimension
+/// and the next slice in the same dimension.
+///   e.g. shape[3, 4, 5] -> linearization_basis[20, 5, 1]
+static SmallVector<int64_t, 8> computeStrides(ArrayRef<int64_t> shape) {
+  if (shape.empty())
+    return {};
+  SmallVector<int64_t, 8> tmp;
+  tmp.reserve(shape.size());
+  int64_t running = 1;
+  for (auto size : llvm::reverse(shape)) {
+    assert(size > 0 && "size must be nonnegative");
+    tmp.push_back(running);
+    running *= size;
+  }
+  return SmallVector<int64_t, 8>(tmp.rbegin(), tmp.rend());
+}
+
+static int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis) {
+  if (basis.empty())
+    return 0;
+  int64_t res = 1;
+  for (auto b : basis)
+    res *= b;
+  return res;
+}
+
+/// Given a shape with sizes greater than 0 along all dimensions, returns the
+/// delinearized components of linearIndex along shape.
+static SmallVector<int64_t, 8> delinearize(int64_t linearIndex,
+                                           ArrayRef<int64_t> basis) {
+  SmallVector<int64_t, 8> res;
+  res.reserve(basis.size());
+  for (unsigned idx = 0, e = basis.size(); idx < e; ++idx) {
+    assert(basis[idx] > 0);
+    res.push_back(linearIndex / basis[idx]);
+    linearIndex %= basis[idx];
+  }
+  // Sanity check.
+  assert(linearIndex == 0 && "linear index remainder must be 0");
+  return res;
+}
+
+static constexpr auto kFakeForkOp = "__fake_fork__";
+static constexpr auto kFakeJoinOp = "__fake_join__";
+static constexpr auto kUnrollAttrName = "__unroll__";
+static constexpr auto kBaseCoordAttrName = "__base_coord__";
+
+// Reads the IntegerArray attribute named `kUnrollAttrName` from `op` and
+// returns its representation as a vector of integers.
+static SmallVector<int64_t, 8> extractUnrollFactors(Operation *op) {
+  SmallVector<int64_t, 8> res;
+  auto unrollAttr = op->getAttr(kUnrollAttrName);
+  if (!unrollAttr)
+    return res;
+  auto unrollArrayAttr = unrollAttr.cast<ArrayAttr>();
+  res.reserve(unrollArrayAttr.size());
+  for (auto attr : unrollArrayAttr) {
+    auto unroll = attr.cast<IntegerAttr>().getValue().getSExtValue();
+    assert(unroll > 0);
+    res.push_back(unroll);
+  }
+  return res;
+}
+
+// Creates a custom `kFakeForkOp` used in progressive lowering to other vector
+// operations.
+static Operation *createFakeForkOp(PatternRewriter &builder, Location loc,
+                                   Value *operand, ArrayRef<Type> resultTypes,
+                                   ArrayRef<int64_t> unrollFactors = {}) {
+  OperationState *forkOp =
+      new OperationState(loc, kFakeForkOp, operand, resultTypes, {});
+  if (!unrollFactors.empty())
+    forkOp->addAttribute(kUnrollAttrName,
+                         builder.getI64ArrayAttr(unrollFactors));
+  return builder.createOperation(*forkOp);
+}
+
+// Creates a custom `kFakeJoinOp` used in progressive lowering to other vector
+// operations.
+static Operation *createFakeJoinOp(PatternRewriter &builder, Location loc,
+                                   ArrayRef<Value *> operands, Type resultType,
+                                   ArrayRef<int64_t> unrollFactors = {},
+                                   ArrayRef<int64_t> baseCoords = {}) {
+  OperationState *joinOp =
+      new OperationState(loc, kFakeJoinOp, operands, resultType, {});
+  if (!unrollFactors.empty())
+    joinOp->addAttribute(kUnrollAttrName,
+                         builder.getI64ArrayAttr(unrollFactors));
+  if (!baseCoords.empty())
+    joinOp->addAttribute(kBaseCoordAttrName,
+                         builder.getI64ArrayAttr(baseCoords));
+  return builder.createOperation(*joinOp);
+}
+
+// Clones `op` into a new operations that takes `operands` and returns
+// `resultTypes`.
+static Operation *cloneOpWithOperandsAndTypes(PatternRewriter &builder,
+                                              Location loc, Operation *op,
+                                              ArrayRef<Value *> operands,
+                                              ArrayRef<Type> resultTypes) {
+  OperationState *res = new OperationState(loc, op->getName().getStringRef(),
+                                           operands, resultTypes, {});
+  return builder.createOperation(*res);
+}
+
+// Helper function for Tablegen.
+static bool hasShape(Value *v, ArrayRef<int64_t> shape) {
+  auto t = v->getType().dyn_cast<ShapedType>();
+  if (!t)
+    return false;
+  return std::equal(t.getShape().begin(), t.getShape().end(), shape.begin());
+}
+
+// Entry point for unrolling declarative pattern rewrites.
+// `op` is unrolled to the `targetShape` as follows, for each of its operands:
+//   1. the unrolled type `unrolledVectorType` and number of unrolled instances
+//   `numUnrolledInstances` are computed from the `targetShape`. For now it is
+//   assumed the unrolling factors divide the vector sizes.
+//   2. a fakeFork cast op is inserted that takes the operand and returns
+//   `numUnrolledInstances` results of type `unrolledVectorType`.
+//   3. the original op is cloned `numUnrolledInstances` times, once for each
+//   result of the fakeFork cast op.
+//   4. a fakeJoin cast op takes all these results and merges them into a single
+//   aggregate vector result whose size matches the original non-unrolled op
+//   operand types.
+//
+// Example:
+//
+//    opA(operand0, operand1)  // numUnrolledInstances = 3
+//
+//            operand0                   operand1
+//               |                          |
+//             fork                       fork
+//        <----------gather all fork ops --------->
+//              /|\                        /|\
+//          f00 f01 f02                f10 f11 f12
+//        <---------- clone op 3 times --------->
+//          opA0(f00, f10), opA1(f01, f11), opA2(f02, f12)
+//                 \            |            /
+//      <-------------------- join ------------------------->
+//
+// Other local patterns then kick in iteratively (including DCE) and compose
+// until all the fakeFork and fakeJoin ops are removed.
+//
+// This will be extended in the future to support more advanced use cases than
+// simple pointwise ops.
+static Value *unrollSingleResultOpMatchingType(PatternRewriter &builder,
+                                               Operation *op,
+                                               ArrayRef<int64_t> targetShape) {
+  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
+                       "]: unrollSingleResultOpMatchingType on func:\n");
+  LLVM_DEBUG(op->getParentOfType<FuncOp>().print(dbgs()));
+  if (!op->getNumResults())
+    assert(false && "Use precondition till RewriterGen can act on nullptr");
+
+  auto shapedType = op->getResult(0)->getType().dyn_cast_or_null<ShapedType>();
+  if (!shapedType || !shapedType.hasStaticShape())
+    assert(false && "Use precondition till RewriterGen can act on nullptr");
+
+  auto shape = shapedType.getShape();
+  auto maybeUnrollFactors = shapeRatio(shape, targetShape);
+  if (!maybeUnrollFactors.hasValue())
+    assert(false && "Use precondition till RewriterGen can act on nullptr");
+  auto unrollFactors = *maybeUnrollFactors;
+
+  auto loc = op->getLoc();
+  auto numUnrolledInstances = computeMaxLinearIndex(unrollFactors);
+  auto unrolledVectorType =
+      VectorType::get(targetShape, shapedType.getElementType());
+  SmallVector<Type, 4> forkedType(numUnrolledInstances, unrolledVectorType);
+  SmallVector<Operation *, 4> forkeds;
+  forkeds.reserve(numUnrolledInstances);
+  // Create a new forkOp for each operand.
+  for (auto *operand : op->getOperands())
+    forkeds.push_back(
+        createFakeForkOp(builder, loc, operand, forkedType, unrollFactors));
+
+  SmallVector<Operation *, 4> newOps;
+  newOps.reserve(numUnrolledInstances);
+  for (int64_t idx = 0; idx < numUnrolledInstances; ++idx) {
+    SmallVector<Value *, 4> operands;
+    operands.reserve(forkeds.size());
+    for (auto *fork : forkeds) {
+      operands.push_back(fork->getResult(idx));
+    }
+    newOps.push_back(cloneOpWithOperandsAndTypes(builder, loc, op, operands,
+                                                 unrolledVectorType));
+  }
+
+  SmallVector<Value *, 4> newOpResults;
+  newOpResults.reserve(newOps.size());
+  for (auto *newOp : newOps)
+    newOpResults.push_back(newOp->getResult(0));
+
+  return createFakeJoinOp(builder, loc, newOpResults, shapedType, unrollFactors,
+                          {0})
+      ->getResult(0);
+}
+
+// Patterns with this benefit just forwards arguments to clean up fake fork and
+// fake joins. It is a nicer and more direct cleanup when we can use it so it
+// kicks in with higher precedence.
+static constexpr int64_t kMatchingFakeForkFakeJoinBenefit = 2;
+// Patterns with this benefit extract subvectors with ExtractElementOp and join
+// them to allow creating subvectors.
+static constexpr int64_t kFakeForkFromBlockArgBenefit = 1;
+
+namespace mlir {
+namespace vector {
+namespace {
+#include "mlir/Dialect/VectorOps/VectorTransformPatterns.h.inc"
+} // end namespace
+} // end namespace vector
+} // end namespace mlir
+
+// Match a fakeFork fed by a fakeJoin and just forward its operands.
+// This is akin to calling `replaceAllUsesOf` but made to play nice with all the
+// other RewritePattern.
+struct ConvertMatchingFakeForkFakeJoinOp : public RewritePattern {
+  ConvertMatchingFakeForkFakeJoinOp(MLIRContext *context)
+      // low-benefit to kick-in late
+      : RewritePattern(kFakeForkOp, kMatchingFakeForkFakeJoinBenefit, context) {
+  }
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override {
+    if (op->getNumOperands() != 1)
+      return matchFailure();
+
+    auto *definingOp = op->getOperand(0)->getDefiningOp();
+    if (!definingOp || definingOp->getName().getStringRef() != kFakeJoinOp)
+      return matchFailure();
+
+    if (definingOp->getNumOperands() != op->getNumResults())
+      return matchFailure();
+
+    for (auto it : llvm::zip(definingOp->getOperands(), op->getResults())) {
+      if (std::get<0>(it)->getType() != std::get<1>(it)->getType())
+        return matchFailure();
+    }
+
+    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
+                         "]: ConvertMatchingFakeForkFakeJoinOp on op: "
+                      << *op << " in func:\n");
+    LLVM_DEBUG(op->getParentOfType<FuncOp>().print(dbgs()));
+    SmallVector<Value *, 4> forwardedOperands;
+    forwardedOperands.append(definingOp->getOperands().begin(),
+                             definingOp->getOperands().end());
+    rewriter.replaceOp(op, forwardedOperands);
+    return matchSuccess();
+  }
+};
+
+// Rewrites a fakeFork, whose (unique) operand is a blockArgument, into multiple
+// vector.strided_slice ops.
+struct ConvertFakeForkFromBlockArgsOp : public RewritePattern {
+  ConvertFakeForkFromBlockArgsOp(MLIRContext *context)
+      // low-benefit to kick-in late
+      : RewritePattern(kFakeForkOp, kFakeForkFromBlockArgBenefit, context) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override {
+    if (op->getNumOperands() != 1)
+      return matchFailure();
+
+    auto *blockArg = op->getOperand(0);
+    if (!isa<BlockArgument>(blockArg))
+      return matchFailure();
+
+    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
+                         "]: ConvertFakeForkFromBlockArgsOp on op: "
+                      << *op << " in func:\n");
+    LLVM_DEBUG(op->getParentOfType<FuncOp>().print(dbgs()));
+
+    // Look at the unroll factors remaining on this op and act on the first one.
+    auto unrollFactorsStorage = extractUnrollFactors(op);
+    ArrayRef<int64_t> unrollFactors{unrollFactorsStorage};
+    if (unrollFactors.empty()) {
+      // No more unrollFactors, just sanity check + forward the unique operand.
+      assert(op->getNumResults() == 1);
+      assert(op->getOperand(0)->getType() == op->getResult(0)->getType());
+      rewriter.replaceOp(op, op->getOperand(0));
+      return matchSuccess();
+    }
+
+    // Strides are always 1 for now.
+    // TODO(b/144845578) support non-1 strides.
+    auto forkedVectorType = op->getOperand(0)->getType().cast<VectorType>();
+    SmallVector<int64_t, 4> strides(unrollFactors.size(), 1);
+    auto nUnrolled = computeMaxLinearIndex(unrollFactors);
+    SmallVector<Value *, 4> extractedVectors;
+    extractedVectors.reserve(op->getNumResults());
+    auto linearizationBasis = computeStrides(unrollFactors);
+    for (unsigned idx = 0; idx < nUnrolled; ++idx) {
+      auto offsets = delinearize(idx, linearizationBasis);
+      offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; }, offsets,
+                       unrollFactors);
+      auto leadingSize =
+          forkedVectorType.getShape().take_front(unrollFactors.size());
+      auto sizes = zipMap([](int64_t v1, int64_t v2) { return v1 / v2; },
+                          leadingSize, unrollFactors);
+      extractedVectors.push_back(
+          rewriter
+              .create<vector::VectorStridedSliceOp>(op->getLoc(), blockArg,
+                                                    offsets, sizes, strides)
+              .getResult());
+    }
+    rewriter.replaceOp(op, extractedVectors);
+    return matchSuccess();
+  }
+};
+
+// Simple DCE for fakeForkOps/fakeJoinOps, we do not want them to escape a
+// transformation (otherwise the transformation is considered incorrect).
+struct FakeForkTrait {
+  static constexpr char const *name = kFakeForkOp;
+};
+struct FakeJoinTrait {
+  static constexpr char const *name = kFakeJoinOp;
+};
+
+template <typename OpNameTrait> struct DCEPattern : public RewritePattern {
+  DCEPattern(MLIRContext *context)
+      // low-benefit to kick-in late
+      : RewritePattern(OpNameTrait::name, 0, context) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override {
+    if (!op->use_empty())
+      return matchFailure();
+    rewriter.eraseOp(op);
+    return matchSuccess();
+  }
+};
+
+void mlir::populateVectorToVectorConversionPatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns,
+    ArrayRef<int64_t> coarseVectorShape, ArrayRef<int64_t> fineVectorShape) {
+  vector::populateWithGenerated(context, &patterns);
+  patterns
+      .insert<ConvertMatchingFakeForkFakeJoinOp, ConvertFakeForkFromBlockArgsOp,
+              DCEPattern<FakeForkTrait>, DCEPattern<FakeJoinTrait>>(context);
+}
index 06016da..3034b71 100644 (file)
@@ -181,7 +181,7 @@ struct MaterializationState {
   SmallVector<int64_t, 8> hwVectorSize;
   VectorType superVectorType;
   VectorType hwVectorType;
-  SmallVector<unsigned, 8> hwVectorInstance;
+  SmallVector<int64_t, 8> hwVectorInstance;
   DenseMap<Value *, Value *> *substitutionsMap;
 };
 
@@ -206,24 +206,24 @@ struct MaterializeVectorsPass : public FunctionPass<MaterializeVectorsPass> {
 /// returns the distance, in number of elements, between a slice in a dimension
 /// and the next slice in the same dimension.
 ///   e.g. shape[3, 4, 5] -> strides[20, 5, 1]
-static SmallVector<unsigned, 8> makeStrides(ArrayRef<unsigned> shape) {
-  SmallVector<unsigned, 8> tmp;
+static SmallVector<int64_t, 8> makeStrides(ArrayRef<int64_t> shape) {
+  SmallVector<int64_t, 8> tmp;
   tmp.reserve(shape.size());
-  unsigned running = 1;
+  int64_t running = 1;
   for (auto rit = shape.rbegin(), reit = shape.rend(); rit != reit; ++rit) {
     assert(*rit > 0 && "size must be greater than 0 along all dimensions of "
                        "shape");
     tmp.push_back(running);
     running *= *rit;
   }
-  return SmallVector<unsigned, 8>(tmp.rbegin(), tmp.rend());
+  return SmallVector<int64_t, 8>(tmp.rbegin(), tmp.rend());
 }
 
 /// Given a shape with sizes greater than 0 along all dimensions, returns the
 /// delinearized components of linearIndex along shape.
-static SmallVector<unsigned, 8> delinearize(unsigned linearIndex,
-                                            ArrayRef<unsigned> shape) {
-  SmallVector<unsigned, 8> res;
+static SmallVector<int64_t, 8> delinearize(int64_t linearIndex,
+                                           ArrayRef<int64_t> shape) {
+  SmallVector<int64_t, 8> res;
   res.reserve(shape.size());
   auto strides = makeStrides(shape);
   for (unsigned idx = 0; idx < strides.size(); ++idx) {
@@ -333,7 +333,7 @@ static Value *substitute(Value *v, VectorType hwVectorType,
 /// vectorization trait at the op level directly.
 static SmallVector<mlir::Value *, 8>
 reindexAffineIndices(OpBuilder b, VectorType hwVectorType,
-                     ArrayRef<unsigned> hwVectorInstance,
+                     ArrayRef<int64_t> hwVectorInstance,
                      ArrayRef<Value *> memrefIndices) {
   auto vectorShape = hwVectorType.getShape();
   assert(hwVectorInstance.size() >= vectorShape.size());
@@ -483,7 +483,7 @@ static AffineMap projectedPermutationMap(VectorTransferOpTy transfer,
 /// reindexAffineIndices.
 static Operation *instantiate(OpBuilder b, VectorTransferReadOp read,
                               VectorType hwVectorType,
-                              ArrayRef<unsigned> hwVectorInstance,
+                              ArrayRef<int64_t> hwVectorInstance,
                               DenseMap<Value *, Value *> *substitutionsMap) {
   SmallVector<Value *, 8> indices =
       map(makePtrDynCaster<Value>(), read.indices());
@@ -507,7 +507,7 @@ static Operation *instantiate(OpBuilder b, VectorTransferReadOp read,
 /// reindexAffineIndices.
 static Operation *instantiate(OpBuilder b, VectorTransferWriteOp write,
                               VectorType hwVectorType,
-                              ArrayRef<unsigned> hwVectorInstance,
+                              ArrayRef<int64_t> hwVectorInstance,
                               DenseMap<Value *, Value *> *substitutionsMap) {
   SmallVector<Value *, 8> indices =
       map(makePtrDynCaster<Value>(), write.indices());
diff --git a/mlir/test/Conversion/VectorConversions/vector-to-vector.mlir b/mlir/test/Conversion/VectorConversions/vector-to-vector.mlir
new file mode 100644 (file)
index 0000000..98645b5
--- /dev/null
@@ -0,0 +1,42 @@
+// RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s
+
+// CHECK-LABEL: func @add4x2
+//      CHECK: %[[V1:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[V2:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[V3:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[V4:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
+// CHECK-NEXT: %[[V5:.*]] = addf %[[V1]], %[[V3]] : vector<2x2xf32>
+// CHECK-NEXT: %[[V6:.*]] = addf %[[V2]], %[[V4]] : vector<2x2xf32>
+// CHECK-NEXT: "__fake_join__"(%[[V5]], %[[V6]]) {__base_coord__ = [0], __unroll__ = [2, 1]} : (vector<2x2xf32>, vector<2x2xf32>) -> vector<4x2xf32>
+func @add4x2(%0: vector<4x2xf32>) -> vector<4x2xf32> {
+  %1 = addf %0, %0: vector<4x2xf32>
+  return %1: vector<4x2xf32>
+}
+
+// CHECK-LABEL: func @add4x4
+//      CHECK: vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: addf %{{.*}}, %{{.*}} : vector<2x2xf32>
+// CHECK-NEXT: addf %{{.*}}, %{{.*}} : vector<2x2xf32>
+// CHECK-NEXT: addf %{{.*}}, %{{.*}} : vector<2x2xf32>
+// CHECK-NEXT: addf %{{.*}}, %{{.*}} : vector<2x2xf32>
+// CHECK-NEXT: vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// CHECK-NEXT: addf %{{.*}}, %{{.*}} : vector<2x2xf32>
+// CHECK-NEXT: addf %{{.*}}, %{{.*}} : vector<2x2xf32>
+// CHECK-NEXT: addf %{{.*}}, %{{.*}} : vector<2x2xf32>
+// CHECK-NEXT: addf %{{.*}}, %{{.*}} : vector<2x2xf32>
+// CHECK-NEXT: "__fake_join__"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {__base_coord__ = [0], __unroll__ = [2, 2]} : (vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>) -> vector<4x4xf32>
+func @add4x4(%0: vector<4x4xf32>, %1: vector<4x4xf32>) -> vector<4x4xf32> {
+  %2 = addf %0, %1: vector<4x4xf32>
+  %3 = addf %1, %2: vector<4x4xf32>
+  return %3: vector<4x4xf32>
+}
index 675b695..788b909 100644 (file)
@@ -8,6 +8,7 @@ add_llvm_library(MLIRTestTransforms
   TestLowerVectorTransfers.cpp
   TestOpaqueLoc.cpp
   TestMemRefStrideCalculation.cpp
+  TestVectorToVectorConversion.cpp
   TestVectorizationUtils.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/test/lib/Transforms/TestVectorToVectorConversion.cpp b/mlir/test/lib/Transforms/TestVectorToVectorConversion.cpp
new file mode 100644 (file)
index 0000000..2550796
--- /dev/null
@@ -0,0 +1,44 @@
+//===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering
+//-------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include <type_traits>
+
+#include "mlir/Conversion/VectorConversions/VectorConversions.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+
+namespace {
+
+struct TestVectorToVectorConversion
+    : public FunctionPass<TestVectorToVectorConversion> {
+  void runOnFunction() override {
+    OwningRewritePatternList patterns;
+    auto *context = &getContext();
+    populateVectorToVectorConversionPatterns(context, patterns);
+    applyPatternsGreedily(getFunction(), patterns);
+  }
+};
+
+} // end anonymous namespace
+
+static PassRegistration<TestVectorToVectorConversion>
+    pass("test-vector-to-vector-conversion",
+         "Test conversion patterns between ops in the vector dialect");
index 4fdb660..f0f1f6b 100644 (file)
@@ -131,7 +131,7 @@ void VectorizerTestPass::testVectorShapeRatio(llvm::raw_ostream &outs) {
       opInst->emitRemark("NOT MATCHED");
     } else {
       outs << "\nmatched: " << *opInst << " with shape ratio: ";
-      interleaveComma(MutableArrayRef<unsigned>(*ratio), outs);
+      interleaveComma(MutableArrayRef<int64_t>(*ratio), outs);
     }
   }
 }
index 628557d..9352be0 100644 (file)
@@ -49,7 +49,7 @@ set(LIBS
   MLIRTestTransforms
   MLIRSupport
   MLIRVectorOps
-  MLIRVectorToLLVM
+  MLIRVectorConversions
 )
 if(MLIR_CUDA_CONVERSIONS_ENABLED)
   list(APPEND LIBS