Make createMaterializeVectorsPass take a vectorSize parameter - NFC
authorNicolas Vasilache <ntv@google.com>
Fri, 29 Mar 2019 16:47:30 +0000 (09:47 -0700)
committerjpienaar <jpienaar@google.com>
Sat, 30 Mar 2019 00:56:12 +0000 (17:56 -0700)
This CL allows the programmatic control of the target hardware vector size when creating a MaterializeVectorsPass.
This is useful for registering passes for the tutorial.

PiperOrigin-RevId: 240996136

mlir/include/mlir/Transforms/Passes.h
mlir/lib/Transforms/MaterializeVectors.cpp
mlir/lib/Transforms/Vectorize.cpp

index 52b3a0f..bbfd43a 100644 (file)
@@ -52,7 +52,8 @@ createVectorizePass(llvm::ArrayRef<int64_t> virtualVectorSize);
 FunctionPassBase *createVectorizerTestPass();
 
 /// Creates a pass to lower super-vectors to target-dependent HW vectors.
-FunctionPassBase *createMaterializeVectorsPass();
+FunctionPassBase *
+createMaterializeVectorsPass(llvm::ArrayRef<int64_t> vectorSize);
 
 /// Creates a loop unrolling pass with the provided parameters.
 /// 'getUnrollFactor' is a function callback for clients to supply a function
index 7e4a459..1810956 100644 (file)
@@ -185,9 +185,8 @@ struct MaterializationState {
   /// of the type and we assert everything is f32.
   /// TODO(ntv): relax the assumptions on admissible element type once a
   /// contract exists.
-  MaterializationState() : hwVectorSize(clVectorSize.size(), 0) {
-    std::copy(clVectorSize.begin(), clVectorSize.end(), hwVectorSize.begin());
-  }
+  MaterializationState(SmallVector<int64_t, 8> sizes) : hwVectorSize(sizes) {}
+
   SmallVector<int64_t, 8> hwVectorSize;
   VectorType superVectorType;
   VectorType hwVectorType;
@@ -195,7 +194,18 @@ struct MaterializationState {
   DenseMap<Value *, Value *> *substitutionsMap;
 };
 
+/// Base state for the vector materialization pass.
+/// Command line arguments are preempted by non-empty pass arguments.
 struct MaterializeVectorsPass : public FunctionPass<MaterializeVectorsPass> {
+  MaterializeVectorsPass()
+      : hwVectorSize(clVectorSize.begin(), clVectorSize.end()) {}
+  MaterializeVectorsPass(ArrayRef<int64_t> hwVectorSize)
+      : MaterializeVectorsPass() {
+    if (!hwVectorSize.empty())
+      this->hwVectorSize.assign(hwVectorSize.begin(), hwVectorSize.end());
+  }
+
+  SmallVector<int64_t, 8> hwVectorSize;
   void runOnFunction() override;
 };
 
@@ -739,11 +749,11 @@ void MaterializeVectorsPass::runOnFunction() {
   LLVM_DEBUG(dbgs() << "\nMaterializeVectors on Function\n");
   LLVM_DEBUG(f->print(dbgs()));
 
-  MaterializationState state;
+  MaterializationState state(hwVectorSize);
   // Get the hardware vector type.
   // TODO(ntv): get elemental type from super-vector type rather than force f32.
   auto subVectorType =
-      VectorType::get(state.hwVectorSize, FloatType::getF32(&getContext()));
+      VectorType::get(hwVectorSize, FloatType::getF32(&getContext()));
 
   // Capture terminators; i.e. vector_transfer_write ops involving a strict
   // super-vector of subVectorType.
@@ -765,8 +775,9 @@ void MaterializeVectorsPass::runOnFunction() {
     signalPassFailure();
 }
 
-FunctionPassBase *mlir::createMaterializeVectorsPass() {
-  return new MaterializeVectorsPass();
+FunctionPassBase *
+mlir::createMaterializeVectorsPass(llvm::ArrayRef<int64_t> vectorSize) {
+  return new MaterializeVectorsPass(vectorSize);
 }
 
 static PassRegistration<MaterializeVectorsPass>
index 0874532..1ae2e04 100644 (file)
@@ -620,12 +620,10 @@ struct Vectorize : public FunctionPass<Vectorize> {
 
 } // end anonymous namespace
 
-Vectorize::Vectorize() {
-  this->vectorSizes.assign(clVirtualVectorSize.begin(),
-                           clVirtualVectorSize.end());
-  this->fastestVaryingPattern.assign(clFastestVaryingPattern.begin(),
-                                     clFastestVaryingPattern.end());
-}
+Vectorize::Vectorize()
+    : vectorSizes(clVirtualVectorSize.begin(), clVirtualVectorSize.end()),
+      fastestVaryingPattern(clFastestVaryingPattern.begin(),
+                            clFastestVaryingPattern.end()) {}
 
 Vectorize::Vectorize(ArrayRef<int64_t> virtualVectorSize) : Vectorize() {
   if (!virtualVectorSize.empty()) {