[mlir] VectorToSCF cleanup
authorMatthias Springer <springerm@google.com>
Fri, 14 May 2021 01:45:13 +0000 (10:45 +0900)
committerMatthias Springer <springerm@google.com>
Fri, 14 May 2021 02:04:37 +0000 (11:04 +0900)
Group functions/structs in namespaces for better code readability.

Depends On D102123

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D102124

mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp

index a209bc4..9972bcf 100644 (file)
@@ -49,52 +49,6 @@ struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
   VectorTransferToSCFOptions options;
 };
 
-/// Given a MemRefType with VectorType element type, unpack one dimension from
-/// the VectorType into the MemRefType.
-///
-/// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
-static MemRefType unpackOneDim(MemRefType type) {
-  auto vectorType = type.getElementType().dyn_cast<VectorType>();
-  auto memrefShape = type.getShape();
-  SmallVector<int64_t, 8> newMemrefShape;
-  newMemrefShape.append(memrefShape.begin(), memrefShape.end());
-  newMemrefShape.push_back(vectorType.getDimSize(0));
-  return MemRefType::get(newMemrefShape,
-                         VectorType::get(vectorType.getShape().drop_front(),
-                                         vectorType.getElementType()));
-}
-
-/// Helper data structure for data and mask buffers.
-struct BufferAllocs {
-  Value dataBuffer;
-  Value maskBuffer;
-};
-
-/// Allocate temporary buffers for data (vector) and mask (if present).
-/// TODO: Parallelism and threadlocal considerations.
-template <typename OpTy>
-static BufferAllocs allocBuffers(OpTy xferOp) {
-  auto &b = ScopedContext::getBuilderRef();
-  OpBuilder::InsertionGuard guard(b);
-  Operation *scope =
-      xferOp->template getParentWithTrait<OpTrait::AutomaticAllocationScope>();
-  assert(scope && "Expected op to be inside automatic allocation scope");
-  b.setInsertionPointToStart(&scope->getRegion(0).front());
-
-  BufferAllocs result;
-  auto bufferType = MemRefType::get({}, xferOp.getVectorType());
-  result.dataBuffer = memref_alloca(bufferType).value;
-
-  if (xferOp.mask()) {
-    auto maskType = MemRefType::get({}, xferOp.mask().getType());
-    Value maskBuffer = memref_alloca(maskType);
-    memref_store(xferOp.mask(), maskBuffer);
-    result.maskBuffer = memref_load(maskBuffer);
-  }
-
-  return result;
-}
-
 /// Given a vector transfer op, calculate which dimension of the `source`
 /// memref should be unpacked in the next application of TransferOpConversion.
 /// A return value of None indicates a broadcast.
@@ -284,6 +238,54 @@ static void maybeApplyPassLabel(OpBuilder &builder, OpTy newXferOp,
     newXferOp->setAttr(kPassLabel, builder.getUnitAttr());
 }
 
+namespace lowering_n_d {
+
+/// Helper data structure for data and mask buffers.
+struct BufferAllocs {
+  Value dataBuffer;
+  Value maskBuffer;
+};
+
+/// Allocate temporary buffers for data (vector) and mask (if present).
+/// TODO: Parallelism and threadlocal considerations.
+template <typename OpTy>
+static BufferAllocs allocBuffers(OpTy xferOp) {
+  auto &b = ScopedContext::getBuilderRef();
+  OpBuilder::InsertionGuard guard(b);
+  Operation *scope =
+      xferOp->template getParentWithTrait<OpTrait::AutomaticAllocationScope>();
+  assert(scope && "Expected op to be inside automatic allocation scope");
+  b.setInsertionPointToStart(&scope->getRegion(0).front());
+
+  BufferAllocs result;
+  auto bufferType = MemRefType::get({}, xferOp.getVectorType());
+  result.dataBuffer = memref_alloca(bufferType).value;
+
+  if (xferOp.mask()) {
+    auto maskType = MemRefType::get({}, xferOp.mask().getType());
+    auto maskBuffer = memref_alloca(maskType).value;
+    memref_store(xferOp.mask(), maskBuffer);
+    result.maskBuffer = memref_load(maskBuffer);
+  }
+
+  return result;
+}
+
+/// Given a MemRefType with VectorType element type, unpack one dimension from
+/// the VectorType into the MemRefType.
+///
+/// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
+static MemRefType unpackOneDim(MemRefType type) {
+  auto vectorType = type.getElementType().dyn_cast<VectorType>();
+  auto memrefShape = type.getShape();
+  SmallVector<int64_t, 8> newMemrefShape;
+  newMemrefShape.append(memrefShape.begin(), memrefShape.end());
+  newMemrefShape.push_back(vectorType.getDimSize(0));
+  return MemRefType::get(newMemrefShape,
+                         VectorType::get(vectorType.getShape().drop_front(),
+                                         vectorType.getElementType()));
+}
+
 /// Given a transfer op, find the memref from which the mask is loaded. This
 /// is similar to Strategy<TransferWriteOp>::getBuffer.
 template <typename OpTy>
@@ -688,6 +690,10 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
   }
 };
 
+} // namespace lowering_n_d
+
+namespace lowering_n_d_unrolled {
+
 /// If the original transfer op has a mask, compute the mask of the new transfer
 /// op (for the current iteration `i`) and assign it.
 template <typename OpTy>
@@ -954,6 +960,10 @@ struct UnrollTransferWriteConversion
   }
 };
 
+} // namespace lowering_n_d_unrolled
+
+namespace lowering_1_d {
+
 /// Compute the indices into the memref for the LoadOp/StoreOp generated as
 /// part of TransferOp1dConversion. Return the memref dimension on which
 /// the transfer is operating. A return value of None indicates a broadcast.
@@ -1114,6 +1124,7 @@ struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
   }
 };
 
+} // namespace lowering_1_d
 } // namespace
 
 namespace mlir {
@@ -1121,19 +1132,21 @@ namespace mlir {
 void populateVectorToSCFConversionPatterns(
     RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
   if (options.unroll) {
-    patterns.add<UnrollTransferReadConversion, UnrollTransferWriteConversion>(
+    patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
+                 lowering_n_d_unrolled::UnrollTransferWriteConversion>(
         patterns.getContext(), options);
   } else {
-    patterns.add<PrepareTransferReadConversion, PrepareTransferWriteConversion,
-                 TransferOpConversion<TransferReadOp>,
-                 TransferOpConversion<TransferWriteOp>>(patterns.getContext(),
-                                                        options);
+    patterns.add<lowering_n_d::PrepareTransferReadConversion,
+                 lowering_n_d::PrepareTransferWriteConversion,
+                 lowering_n_d::TransferOpConversion<TransferReadOp>,
+                 lowering_n_d::TransferOpConversion<TransferWriteOp>>(
+        patterns.getContext(), options);
   }
 
   if (options.targetRank == 1) {
-    patterns.add<TransferOp1dConversion<TransferReadOp>,
-                 TransferOp1dConversion<TransferWriteOp>>(patterns.getContext(),
-                                                          options);
+    patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
+                 lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
+        patterns.getContext(), options);
   }
 }