[MLIR][Vector] Add support for TupleGetOp folding through InsertSlicesOp and ExtractS...
authorAndy Davis <andydavis@google.com>
Tue, 31 Mar 2020 15:21:04 +0000 (08:21 -0700)
committerAndy Davis <andydavis@google.com>
Tue, 31 Mar 2020 15:39:17 +0000 (08:39 -0700)
Summary:
Add support for TupleGetOp folding through InsertSlicesOp and ExtractSlicesOp.
Vector-to-vector transformations for unrolling and lowering to hardware vectors
can generate chains of structured vector operations (InsertSlicesOp,
ExtractSlicesOp and ShapeCastOp) between the producer of a hardware vector
value and its consumer. Because InsertSlicesOp, ExtractSlicesOp and ShapeCastOp
are structured, we can track the location (tuple index and vector offsets) of
the consumer vector value through the chain of structured operations to the
producer, enabling a much more powerful producer-consumer fowarding of values
through structured ops and tuple, which in turn enables a more powerful
TupleGetOp folding transformation.

Reviewers: nicolasvasilache, aartbik

Reviewed By: aartbik

Subscribers: grosul1, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D76889

mlir/include/mlir/Dialect/Vector/VectorUtils.h
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/lib/Dialect/Vector/VectorUtils.cpp
mlir/test/Dialect/Vector/vector-transforms.mlir

index 4bc03e4..3552754 100644 (file)
@@ -31,6 +31,9 @@ class VectorType;
 SmallVector<int64_t, 4> computeStrides(ArrayRef<int64_t> shape,
                                        ArrayRef<int64_t> sizes);
 
+/// Computes and returns the linearized index of 'offsets' w.r.t. 'basis'.
+int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis);
+
 /// Given the slice strides together with a linear index in the dimension
 /// space, returns the vector-space offsets in each dimension for a
 /// de-linearized index.
index 2716d62..dbb0bf4 100644 (file)
@@ -69,15 +69,6 @@ static int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis) {
   return res;
 }
 
-/// Computes and returns the linearized index of 'offsets' w.r.t. 'basis'.
-static int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
-  assert(offsets.size() == basis.size());
-  int64_t linearIndex = 0;
-  for (unsigned idx = 0, e = basis.size(); idx < e; ++idx)
-    linearIndex += offsets[idx] * basis[idx];
-  return linearIndex;
-}
-
 // Clones `op` into a new operations that takes `operands` and returns
 // `resultTypes`.
 static Operation *cloneOpWithOperandsAndTypes(PatternRewriter &builder,
@@ -683,6 +674,99 @@ struct ShapeCastOpDecomposer : public OpRewritePattern<vector::ShapeCastOp> {
   }
 };
 
+/// Returns the producer Value of the same type as 'consumerValue', by tracking
+/// the tuple index and offsets of the consumer vector value through the
+/// chain of operations (TupleGetOp, InsertSlicesOp, ExtractSlicesOp, TupleOp)
+/// from consumer to producer. Each operation in the chain is structured, and
+/// so the tuple index and offsets can be mapped from result to input, while
+/// visiting each operation in the chain.
+/// Returns nullptr on failure.
+static Value getProducerValue(Value consumerValue) {
+  auto consumerVectorType = consumerValue.getType().cast<VectorType>();
+  // A tupleIndex == -1 indicates that 'offsets' are w.r.t a vector type.
+  int64_t tupleIndex = -1;
+  SmallVector<int64_t, 4> offsets(consumerVectorType.getRank(), 0);
+  auto *op = consumerValue.getDefiningOp();
+  while (op != nullptr) {
+    if (auto tupleGetOp = dyn_cast<vector::TupleGetOp>(op)) {
+      assert(tupleIndex == -1 && "TupleGetOp must have vector result type");
+
+      // Update 'tupleIndex' and next defining 'op' to visit.
+      tupleIndex = tupleGetOp.getIndex();
+      op = tupleGetOp.vectors().getDefiningOp();
+    } else if (auto extractSlicesOp = dyn_cast<vector::ExtractSlicesOp>(op)) {
+      assert(tupleIndex >= 0);
+
+      // Compute slice strides for 'extractSlicesOp'.
+      SmallVector<int64_t, 4> sizes;
+      extractSlicesOp.getSizes(sizes);
+      auto sliceStrides = computeStrides(
+          extractSlicesOp.getSourceVectorType().getShape(), sizes);
+
+      // Compute 'elementOffsets' into 'extractSlicesOp' input vector type,
+      // of 'extractSlicesOp' result vector tuple element at 'tupleIndex'.
+      auto vectorOffsets = delinearize(sliceStrides, tupleIndex);
+      auto elementOffsets =
+          computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
+
+      // Add 'elementOffsets' to 'offsets' so that 'offsets' are now relative
+      // to the 'extractSlicesOp' input vector type.
+      assert(offsets.size() == elementOffsets.size());
+      for (unsigned i = 0, e = offsets.size(); i < e; ++i)
+        offsets[i] += elementOffsets[i];
+
+      // Clear 'tupleIndex' and update next defining 'op' to visit.
+      tupleIndex = -1;
+      op = extractSlicesOp.vector().getDefiningOp();
+    } else if (auto insertSlicesOp = dyn_cast<vector::InsertSlicesOp>(op)) {
+      assert(tupleIndex == -1);
+
+      // Compute slice strides for 'insertSlicesOp'.
+      SmallVector<int64_t, 4> sizes;
+      insertSlicesOp.getSizes(sizes);
+      auto sliceStrides = computeStrides(
+          insertSlicesOp.getResultVectorType().getShape(), sizes);
+
+      // Compute 'vectorOffsets' of 'insertSlicesOp' input vector slice,
+      // of 'insertSlicesOp' result vector type at 'offsets'.
+      SmallVector<int64_t, 4> vectorOffsets(offsets.size());
+      assert(offsets.size() == sizes.size());
+      for (unsigned i = 0, e = offsets.size(); i < e; ++i)
+        vectorOffsets[i] = offsets[i] / sizes[i];
+
+      // Compute the source tuple element index.
+      tupleIndex = linearize(vectorOffsets, sliceStrides);
+
+      // Subtract 'elementOffsets' from 'offsets' so that 'offsets' are now
+      // relative to input tuple element vector type at 'tupleIndex'.
+      auto elementOffsets =
+          computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
+      assert(offsets.size() == elementOffsets.size());
+      for (unsigned i = 0, e = offsets.size(); i < e; ++i) {
+        offsets[i] -= elementOffsets[i];
+        assert(offsets[i] >= 0);
+      }
+
+      // Update next defining 'op' to visit.
+      op = insertSlicesOp.vectors().getDefiningOp();
+    } else if (auto tupleOp = dyn_cast<vector::TupleOp>(op)) {
+      assert(tupleIndex >= 0);
+
+      // Return tuple element 'value' at 'tupleIndex' if it matches type.
+      auto value = tupleOp.getOperand(tupleIndex);
+      if (value.getType() == consumerVectorType)
+        return value;
+
+      // Update 'tupleIndex' and next defining 'op' to visit.
+      tupleIndex = -1;
+      op = value.getDefiningOp();
+    } else {
+      break;
+    }
+  }
+  return nullptr;
+}
+
 /// ShapeCastOpFolder folds cancelling ShapeCastOps away.
 //
 // Example:
@@ -740,28 +824,11 @@ struct TupleGetFolderOp : public OpRewritePattern<vector::TupleGetOp> {
 
   LogicalResult matchAndRewrite(vector::TupleGetOp tupleGetOp,
                                 PatternRewriter &rewriter) const override {
-    // Return if 'tupleGetOp.vectors' arg was not defined by ExtractSlicesOp.
-    auto extractSlicesOp = dyn_cast_or_null<vector::ExtractSlicesOp>(
-        tupleGetOp.vectors().getDefiningOp());
-    if (!extractSlicesOp)
-      return failure();
-
-    // Return if 'extractSlicesOp.vector' arg was not defined by InsertSlicesOp.
-    auto insertSlicesOp = dyn_cast_or_null<vector::InsertSlicesOp>(
-        extractSlicesOp.vector().getDefiningOp());
-    if (!insertSlicesOp)
-      return failure();
-
-    // Return if 'insertSlicesOp.vectors' arg was not defined by TupleOp.
-    auto tupleOp = dyn_cast_or_null<vector::TupleOp>(
-        insertSlicesOp.vectors().getDefiningOp());
-    if (!tupleOp)
-      return failure();
-
-    // Forward Value from 'tupleOp' at 'tupleGetOp.index'.
-    Value tupleValue = tupleOp.getOperand(tupleGetOp.getIndex());
-    rewriter.replaceOp(tupleGetOp, tupleValue);
-    return success();
+    if (auto producer = getProducerValue(tupleGetOp.getResult())) {
+      rewriter.replaceOp(tupleGetOp, producer);
+      return success();
+    }
+    return failure();
   }
 };
 
index f929ddd..9038b7a 100644 (file)
 
 using llvm::SetVector;
 
-namespace mlir {
+using namespace mlir;
 
-SmallVector<int64_t, 4> computeStrides(ArrayRef<int64_t> shape,
-                                       ArrayRef<int64_t> sizes) {
+SmallVector<int64_t, 4> mlir::computeStrides(ArrayRef<int64_t> shape,
+                                             ArrayRef<int64_t> sizes) {
   int64_t rank = shape.size();
   // Compute the count for each dimension.
   SmallVector<int64_t, 4> sliceDimCounts(rank);
@@ -45,8 +45,16 @@ SmallVector<int64_t, 4> computeStrides(ArrayRef<int64_t> shape,
   return sliceStrides;
 }
 
-SmallVector<int64_t, 4> delinearize(ArrayRef<int64_t> sliceStrides,
-                                    int64_t index) {
+int64_t mlir::linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
+  assert(offsets.size() == basis.size());
+  int64_t linearIndex = 0;
+  for (unsigned idx = 0, e = basis.size(); idx < e; ++idx)
+    linearIndex += offsets[idx] * basis[idx];
+  return linearIndex;
+}
+
+SmallVector<int64_t, 4> mlir::delinearize(ArrayRef<int64_t> sliceStrides,
+                                          int64_t index) {
   int64_t rank = sliceStrides.size();
   SmallVector<int64_t, 4> vectorOffsets(rank);
   for (int64_t r = 0; r < rank; ++r) {
@@ -57,16 +65,15 @@ SmallVector<int64_t, 4> delinearize(ArrayRef<int64_t> sliceStrides,
   return vectorOffsets;
 }
 
-SmallVector<int64_t, 4>
-computeElementOffsetsFromVectorSliceOffsets(ArrayRef<int64_t> sizes,
-                                            ArrayRef<int64_t> vectorOffsets) {
+SmallVector<int64_t, 4> mlir::computeElementOffsetsFromVectorSliceOffsets(
+    ArrayRef<int64_t> sizes, ArrayRef<int64_t> vectorOffsets) {
   return functional::zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
                             vectorOffsets, sizes);
 }
 
-SmallVector<int64_t, 4> computeSliceSizes(ArrayRef<int64_t> shape,
-                                          ArrayRef<int64_t> sizes,
-                                          ArrayRef<int64_t> elementOffsets) {
+SmallVector<int64_t, 4>
+mlir::computeSliceSizes(ArrayRef<int64_t> shape, ArrayRef<int64_t> sizes,
+                        ArrayRef<int64_t> elementOffsets) {
   int64_t rank = shape.size();
   SmallVector<int64_t, 4> sliceSizes(rank);
   for (unsigned r = 0; r < rank; ++r)
@@ -74,8 +81,8 @@ SmallVector<int64_t, 4> computeSliceSizes(ArrayRef<int64_t> shape,
   return sliceSizes;
 }
 
-Optional<SmallVector<int64_t, 4>> shapeRatio(ArrayRef<int64_t> superShape,
-                                             ArrayRef<int64_t> subShape) {
+Optional<SmallVector<int64_t, 4>> mlir::shapeRatio(ArrayRef<int64_t> superShape,
+                                                   ArrayRef<int64_t> subShape) {
   if (superShape.size() < subShape.size()) {
     return Optional<SmallVector<int64_t, 4>>();
   }
@@ -114,8 +121,8 @@ Optional<SmallVector<int64_t, 4>> shapeRatio(ArrayRef<int64_t> superShape,
   return SmallVector<int64_t, 4>{result.rbegin(), result.rend()};
 }
 
-Optional<SmallVector<int64_t, 4>> shapeRatio(VectorType superVectorType,
-                                             VectorType subVectorType) {
+Optional<SmallVector<int64_t, 4>> mlir::shapeRatio(VectorType superVectorType,
+                                                   VectorType subVectorType) {
   assert(superVectorType.getElementType() == subVectorType.getElementType() &&
          "vector types must be of the same elemental type");
   return shapeRatio(superVectorType.getShape(), subVectorType.getShape());
@@ -201,9 +208,9 @@ static SetVector<Operation *> getEnclosingforOps(Operation *op) {
   return getParentsOfType<AffineForOp>(op);
 }
 
-AffineMap
-makePermutationMap(Operation *op, ArrayRef<Value> indices,
-                   const DenseMap<Operation *, unsigned> &loopToVectorDim) {
+AffineMap mlir::makePermutationMap(
+    Operation *op, ArrayRef<Value> indices,
+    const DenseMap<Operation *, unsigned> &loopToVectorDim) {
   DenseMap<Operation *, unsigned> enclosingLoopToVectorDim;
   auto enclosingLoops = getEnclosingforOps(op);
   for (auto *forInst : enclosingLoops) {
@@ -212,7 +219,7 @@ makePermutationMap(Operation *op, ArrayRef<Value> indices,
       enclosingLoopToVectorDim.insert(*it);
     }
   }
-  return makePermutationMap(indices, enclosingLoopToVectorDim);
+  return ::makePermutationMap(indices, enclosingLoopToVectorDim);
 }
 
 bool matcher::operatesOnSuperVectorsOf(Operation &op,
@@ -275,4 +282,3 @@ bool matcher::operatesOnSuperVectorsOf(Operation &op,
   return true;
 }
 
-} // namespace mlir
index 7582758..082afba 100644 (file)
@@ -313,6 +313,95 @@ func @tuple_get(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
   return %1 : vector<8xf32>
 }
 
+// CHECK-LABEL: func @tuple_get_producer_consumer
+// CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A1:.*1]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A2:.*2]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A3:.*3]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A4:.*4]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A5:.*5]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A6:.*6]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A7:.*7]]: vector<2x4xf32>
+//      CHECK: return %[[A7]] : vector<2x4xf32>
+
+func @tuple_get_producer_consumer(
+  %arg0 : vector<2x4xf32>, %arg1 : vector<2x4xf32>,
+  %arg2 : vector<2x4xf32>, %arg3 : vector<2x4xf32>,
+  %arg4 : vector<2x4xf32>, %arg5 : vector<2x4xf32>,
+  %arg6 : vector<2x4xf32>, %arg7 : vector<2x4xf32>) -> vector<2x4xf32> {
+  %0 = vector.tuple %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7
+    : vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
+      vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>
+  // %arg7 == %0 at tupleIndex = 7, offsets = [0, 0]
+  %1 = vector.insert_slices %0, [2, 4], [1, 1]
+    : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
+            vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
+      into vector<4x16xf32>
+  // %arg7 == %1 at tupleIndex = -1, offsets = [2, 12]
+  %2 = vector.extract_slices %1, [4, 8], [1, 1]
+    : vector<4x16xf32> into tuple<vector<4x8xf32>, vector<4x8xf32>>
+  // %arg7 == %2 at tupleIndex = 1, offsets = [2, 4]
+  %3 = vector.tuple_get %2, 1 : tuple<vector<4x8xf32>, vector<4x8xf32>>
+  // %arg7 == %3 at tupleIndex = -1, offsets = [2, 4]
+  %4 = vector.extract_slices %3, [2, 4], [1, 1]
+    : vector<4x8xf32> into
+      tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
+  // %arg7 == %4 at tupleIndex = 3, offsets = [0, 0]
+  %5 = vector.tuple_get %4, 3
+    : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
+  // %arg7 == %5
+  return %5 : vector<2x4xf32>
+}
+
+// CHECK-LABEL: func @tuple_get_producer_consumer_swizzle
+// CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A1:.*1]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A2:.*2]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A3:.*3]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A4:.*4]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A5:.*5]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A6:.*6]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A7:.*7]]: vector<2x4xf32>
+//      CHECK: return %[[A7]] : vector<2x4xf32>
+
+func @tuple_get_producer_consumer_swizzle(
+  %arg0 : vector<2x4xf32>, %arg1 : vector<2x4xf32>,
+  %arg2 : vector<2x4xf32>, %arg3 : vector<2x4xf32>,
+  %arg4 : vector<2x4xf32>, %arg5 : vector<2x4xf32>,
+  %arg6 : vector<2x4xf32>, %arg7 : vector<2x4xf32>) -> vector<2x4xf32> {
+  %0 = vector.tuple %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7
+    : vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
+      vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>
+  // %arg7 == %0 at tupleIndex = 7, offsets = [0, 0]
+  %1 = vector.insert_slices %0, [2, 4], [1, 1]
+    : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
+            vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
+      into vector<4x16xf32>
+  // %arg7 == %1 at tupleIndex = -1, offsets = [2, 12]
+  %2 = vector.extract_slices %1, [4, 8], [1, 1]
+    : vector<4x16xf32> into tuple<vector<4x8xf32>, vector<4x8xf32>>
+  // %arg7 == %2 at tupleIndex = 1, offsets = [2, 4]
+
+  // Extract tuple elements.
+  %3 = vector.tuple_get %2, 0 : tuple<vector<4x8xf32>, vector<4x8xf32>>
+  %4 = vector.tuple_get %2, 1 : tuple<vector<4x8xf32>, vector<4x8xf32>>
+  // %arg7 == %4 at tupleIndex = -1, offsets = [2, 4]
+
+  // Swizzle tuple elements.
+  %5 = vector.tuple %4, %3 : vector<4x8xf32>, vector<4x8xf32>
+  // %arg7 == %5 at tupleIndex = 0, offsets = [2, 4]
+  %6 = vector.tuple_get %5, 0 : tuple<vector<4x8xf32>, vector<4x8xf32>>
+  // %arg7 == %6 at tupleIndex = -1, offsets = [2, 4]
+  %7 = vector.extract_slices %6, [2, 4], [1, 1]
+    : vector<4x8xf32> into
+      tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
+  // %arg7 == %7 at tupleIndex = 3, offsets = [0, 0]
+  %8 = vector.tuple_get %7, 3
+    : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
+  // %arg7 == %8
+  return %8 : vector<2x4xf32>
+}
+
 // CHECK-LABEL: func @vector_transfers_vector_element_type
 //      CHECK: %[[C0:.*]] = constant 0 : index
 //      CHECK: %[[C1:.*]] = constant 1 : index