From: Nicolas Vasilache Date: Fri, 29 Mar 2019 16:47:30 +0000 (-0700) Subject: Make createMaterializeVectorsPass take a vectorSize parameter - NFC X-Git-Tag: llvmorg-11-init~1466^2~2076 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=f93a5be65f1a69093a51252b94c8a0c93cb546b4;p=platform%2Fupstream%2Fllvm.git Make createMaterializeVectorsPass take a vectorSize parameter - NFC This CL allows the programmatic control of the target hardware vector size when creating a MaterializeVectorsPass. This is useful for registering passes for the tutorial. PiperOrigin-RevId: 240996136 --- diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index 52b3a0f..bbfd43a 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -52,7 +52,8 @@ createVectorizePass(llvm::ArrayRef virtualVectorSize); FunctionPassBase *createVectorizerTestPass(); /// Creates a pass to lower super-vectors to target-dependent HW vectors. -FunctionPassBase *createMaterializeVectorsPass(); +FunctionPassBase * +createMaterializeVectorsPass(llvm::ArrayRef vectorSize); /// Creates a loop unrolling pass with the provided parameters. /// 'getUnrollFactor' is a function callback for clients to supply a function diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 7e4a459..1810956 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -185,9 +185,8 @@ struct MaterializationState { /// of the type and we assert everything is f32. /// TODO(ntv): relax the assumptions on admissible element type once a /// contract exists. - MaterializationState() : hwVectorSize(clVectorSize.size(), 0) { - std::copy(clVectorSize.begin(), clVectorSize.end(), hwVectorSize.begin()); - } + MaterializationState(SmallVector sizes) : hwVectorSize(sizes) {} + SmallVector hwVectorSize; VectorType superVectorType; VectorType hwVectorType; @@ -195,7 +194,18 @@ struct MaterializationState { DenseMap *substitutionsMap; }; +/// Base state for the vector materialization pass. +/// Command line arguments are preempted by non-empty pass arguments. struct MaterializeVectorsPass : public FunctionPass { + MaterializeVectorsPass() + : hwVectorSize(clVectorSize.begin(), clVectorSize.end()) {} + MaterializeVectorsPass(ArrayRef hwVectorSize) + : MaterializeVectorsPass() { + if (!hwVectorSize.empty()) + this->hwVectorSize.assign(hwVectorSize.begin(), hwVectorSize.end()); + } + + SmallVector hwVectorSize; void runOnFunction() override; }; @@ -739,11 +749,11 @@ void MaterializeVectorsPass::runOnFunction() { LLVM_DEBUG(dbgs() << "\nMaterializeVectors on Function\n"); LLVM_DEBUG(f->print(dbgs())); - MaterializationState state; + MaterializationState state(hwVectorSize); // Get the hardware vector type. // TODO(ntv): get elemental type from super-vector type rather than force f32. auto subVectorType = - VectorType::get(state.hwVectorSize, FloatType::getF32(&getContext())); + VectorType::get(hwVectorSize, FloatType::getF32(&getContext())); // Capture terminators; i.e. vector_transfer_write ops involving a strict // super-vector of subVectorType. @@ -765,8 +775,9 @@ void MaterializeVectorsPass::runOnFunction() { signalPassFailure(); } -FunctionPassBase *mlir::createMaterializeVectorsPass() { - return new MaterializeVectorsPass(); +FunctionPassBase * +mlir::createMaterializeVectorsPass(llvm::ArrayRef vectorSize) { + return new MaterializeVectorsPass(vectorSize); } static PassRegistration diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index 0874532..1ae2e04 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -620,12 +620,10 @@ struct Vectorize : public FunctionPass { } // end anonymous namespace -Vectorize::Vectorize() { - this->vectorSizes.assign(clVirtualVectorSize.begin(), - clVirtualVectorSize.end()); - this->fastestVaryingPattern.assign(clFastestVaryingPattern.begin(), - clFastestVaryingPattern.end()); -} +Vectorize::Vectorize() + : vectorSizes(clVirtualVectorSize.begin(), clVirtualVectorSize.end()), + fastestVaryingPattern(clFastestVaryingPattern.begin(), + clFastestVaryingPattern.end()) {} Vectorize::Vectorize(ArrayRef virtualVectorSize) : Vectorize() { if (!virtualVectorSize.empty()) {