From 2666b97314ad1b50f88fcc4376ae941f601f67ea Mon Sep 17 00:00:00 2001 From: River Riddle Date: Wed, 18 Dec 2019 10:46:16 -0800 Subject: [PATCH] NFC: Cleanup non-conforming usages of namespaces. * Fixes use of anonymous namespace for static methods. * Uses explicit qualifiers(mlir::) instead of wrapping the definition with the namespace. PiperOrigin-RevId: 286222654 --- .../Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp | 11 +++--- .../VectorToLoops/ConvertVectorToLoops.cpp | 15 ++++---- .../Dialect/QuantOps/Utils/FakeQuantSupport.cpp | 40 +++++++++----------- mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp | 20 +++++----- mlir/lib/Dialect/SDBM/SDBMExpr.cpp | 12 ++---- .../DecorateSPIRVCompositeTypeLayoutPass.cpp | 2 +- mlir/lib/ExecutionEngine/ExecutionEngine.cpp | 7 +--- mlir/lib/Quantizer/Support/Statistics.cpp | 9 +---- mlir/lib/Quantizer/Support/UniformSolvers.cpp | 14 +++---- mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp | 2 +- mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp | 2 +- mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 43 ++++++++++------------ mlir/lib/Transforms/LoopFusion.cpp | 22 ++++++----- 13 files changed, 88 insertions(+), 111 deletions(-) diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp index 92cc026..42483a6 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp @@ -18,6 +18,7 @@ // This file implements the conversion patterns from GPU ops to SPIR-V dialect. // //===----------------------------------------------------------------------===// +#include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/LoopOps/LoopOps.h" #include "mlir/Dialect/SPIRV/SPIRVDialect.h" @@ -350,11 +351,10 @@ PatternMatchResult GPUReturnOpConversion::matchAndRewrite( // GPU To SPIRV Patterns. //===----------------------------------------------------------------------===// -namespace mlir { -void populateGPUToSPIRVPatterns(MLIRContext *context, - SPIRVTypeConverter &typeConverter, - OwningRewritePatternList &patterns, - ArrayRef workGroupSize) { +void mlir::populateGPUToSPIRVPatterns(MLIRContext *context, + SPIRVTypeConverter &typeConverter, + OwningRewritePatternList &patterns, + ArrayRef workGroupSize) { patterns.insert(context, typeConverter, workGroupSize); patterns.insert< GPUReturnOpConversion, ForOpConversion, KernelModuleConversion, @@ -366,4 +366,3 @@ void populateGPUToSPIRVPatterns(MLIRContext *context, spirv::BuiltIn::LocalInvocationId>>(context, typeConverter); } -} // namespace mlir diff --git a/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp b/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp index 721e709..0b39f60 100644 --- a/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp +++ b/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp @@ -117,14 +117,16 @@ struct VectorTransferRewriter : public RewritePattern { PatternRewriter &rewriter) const override; }; +} // namespace + /// Analyzes the `transfer` to find an access dimension along the fastest remote /// MemRef dimension. If such a dimension with coalescing properties is found, /// `pivs` and `vectorView` are swapped so that the invocation of /// LoopNestBuilder captures it in the innermost loop. template -void coalesceCopy(TransferOpTy transfer, - SmallVectorImpl *pivs, - edsc::VectorView *vectorView) { +static void coalesceCopy(TransferOpTy transfer, + SmallVectorImpl *pivs, + edsc::VectorView *vectorView) { // rank of the remote memory access, coalescing behavior occurs on the // innermost memory dimension. auto remoteRank = transfer.getMemRefType().getRank(); @@ -155,9 +157,9 @@ void coalesceCopy(TransferOpTy transfer, /// Emits remote memory accesses that are clipped to the boundaries of the /// MemRef. template -SmallVector clip(TransferOpTy transfer, - edsc::MemRefView &view, - ArrayRef ivs) { +static SmallVector clip(TransferOpTy transfer, + edsc::MemRefView &view, + ArrayRef ivs) { using namespace mlir::edsc; using namespace edsc::op; using edsc::intrinsics::select; @@ -357,7 +359,6 @@ PatternMatchResult VectorTransferRewriter::matchAndRewrite( rewriter.eraseOp(op); return matchSuccess(); } -} // namespace void mlir::populateVectorToAffineLoopsConversionPatterns( MLIRContext *context, OwningRewritePatternList &patterns) { diff --git a/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp b/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp index 10668f8..f4256cf 100644 --- a/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp +++ b/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp @@ -18,12 +18,13 @@ #include "mlir/Dialect/QuantOps/FakeQuantSupport.h" #include "mlir/Dialect/QuantOps/QuantTypes.h" -namespace mlir { -namespace quant { -namespace { -bool getDefaultStorageParams(unsigned numBits, bool narrowRange, bool isSigned, - MLIRContext *ctx, Type &storageType, int64_t &qmin, - int64_t &qmax) { +using namespace mlir; +using namespace mlir::quant; + +static bool getDefaultStorageParams(unsigned numBits, bool narrowRange, + bool isSigned, MLIRContext *ctx, + Type &storageType, int64_t &qmin, + int64_t &qmax) { // Hard-coded type mapping from TFLite. if (numBits <= 8) { storageType = IntegerType::get(8, ctx); @@ -62,9 +63,9 @@ bool getDefaultStorageParams(unsigned numBits, bool narrowRange, bool isSigned, // range will be outside the shifted range and be clamped during quantization. // TODO(fengliuai): we should nudge the scale as well, but that requires the // fake quant op used in the training to use the nudged scale as well. -void getNudgedScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin, - double rmax, double &scale, - int64_t &nudgedZeroPoint) { +static void getNudgedScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin, + double rmax, double &scale, + int64_t &nudgedZeroPoint) { // Determine the scale. const double qminDouble = qmin; const double qmaxDouble = qmax; @@ -103,12 +104,10 @@ void getNudgedScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin, assert(nudgedZeroPoint <= qmax); } -} // end namespace - -UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits, - double rmin, double rmax, - bool narrowRange, Type expressedType, - bool isSigned) { +UniformQuantizedType +mlir::quant::fakeQuantAttrsToType(Location loc, unsigned numBits, double rmin, + double rmax, bool narrowRange, + Type expressedType, bool isSigned) { MLIRContext *ctx = expressedType.getContext(); unsigned flags = isSigned ? QuantizationFlags::Signed : 0; Type storageType; @@ -137,10 +136,10 @@ UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits, loc); } -UniformQuantizedPerAxisType -fakeQuantAttrsToType(Location loc, unsigned numBits, int32_t quantizedDimension, - ArrayRef rmins, ArrayRef rmaxs, - bool narrowRange, Type expressedType, bool isSigned) { +UniformQuantizedPerAxisType mlir::quant::fakeQuantAttrsToType( + Location loc, unsigned numBits, int32_t quantizedDimension, + ArrayRef rmins, ArrayRef rmaxs, bool narrowRange, + Type expressedType, bool isSigned) { size_t axis_size = rmins.size(); if (axis_size != rmaxs.size()) { return (emitError(loc, "mismatched per-axis min and max size: ") @@ -183,6 +182,3 @@ fakeQuantAttrsToType(Location loc, unsigned numBits, int32_t quantizedDimension, flags, storageType, expressedType, scales, zeroPoints, quantizedDimension, qmin, qmax, loc); } - -} // namespace quant -} // namespace mlir diff --git a/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp b/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp index e7a1df9..56e2cba 100644 --- a/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp +++ b/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp @@ -20,8 +20,9 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/StandardTypes.h" -namespace mlir { -namespace quant { +using namespace mlir; +using namespace mlir::quant; + /// Converts a possible primitive, real expressed value attribute to a /// corresponding storage attribute (typically FloatAttr -> IntegerAttr). /// quantizedElementType is the QuantizedType that describes the expressed @@ -104,10 +105,9 @@ convertSparseElementsAttr(SparseElementsAttr realSparseAttr, /// Converts a real expressed Attribute to a corresponding Attribute containing /// quantized storage values assuming the given uniform quantizedElementType and /// converter. -Attribute quantizeAttrUniform(Attribute realValue, - UniformQuantizedType quantizedElementType, - const UniformQuantizedValueConverter &converter, - Type &outConvertedType) { +Attribute mlir::quant::quantizeAttrUniform( + Attribute realValue, UniformQuantizedType quantizedElementType, + const UniformQuantizedValueConverter &converter, Type &outConvertedType) { // Fork to handle different variants of constants supported. if (realValue.isa()) { // Dense tensor or vector constant. @@ -133,8 +133,9 @@ Attribute quantizeAttrUniform(Attribute realValue, /// quantizedElementType.getStorageType(). /// Returns nullptr if the conversion is not supported. /// On success, stores the converted type in outConvertedType. -Attribute quantizeAttr(Attribute realValue, QuantizedType quantizedElementType, - Type &outConvertedType) { +Attribute mlir::quant::quantizeAttr(Attribute realValue, + QuantizedType quantizedElementType, + Type &outConvertedType) { if (auto uniformQuantized = quantizedElementType.dyn_cast()) { UniformQuantizedValueConverter converter(uniformQuantized); @@ -154,6 +155,3 @@ Attribute quantizeAttr(Attribute realValue, QuantizedType quantizedElementType, return nullptr; } } - -} // namespace quant -} // namespace mlir diff --git a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp index 8cdd9c8..44cdd18 100644 --- a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp +++ b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp @@ -671,10 +671,7 @@ SDBMDirectExpr SDBMNegExpr::getVar() const { return static_cast(impl)->expr; } -namespace mlir { -namespace ops_assertions { - -SDBMExpr operator+(SDBMExpr lhs, SDBMExpr rhs) { +SDBMExpr mlir::ops_assertions::operator+(SDBMExpr lhs, SDBMExpr rhs) { if (auto folded = foldSumDiff(lhs, rhs)) return folded; assert(!(lhs.isa() && rhs.isa()) && @@ -707,7 +704,7 @@ SDBMExpr operator+(SDBMExpr lhs, SDBMExpr rhs) { return addConstant(lhs.cast(), rhsConstant.getValue()); } -SDBMExpr operator-(SDBMExpr lhs, SDBMExpr rhs) { +SDBMExpr mlir::ops_assertions::operator-(SDBMExpr lhs, SDBMExpr rhs) { // Fold x - x == 0. if (lhs == rhs) return SDBMConstantExpr::get(lhs.getDialect(), 0); @@ -734,7 +731,7 @@ SDBMExpr operator-(SDBMExpr lhs, SDBMExpr rhs) { return buildDiffExpr(lhs.cast(), (-rhs).cast()); } -SDBMExpr stripe(SDBMExpr expr, SDBMExpr factor) { +SDBMExpr mlir::ops_assertions::stripe(SDBMExpr expr, SDBMExpr factor) { auto constantFactor = factor.cast(); assert(constantFactor.getValue() > 0 && "non-positive stripe"); @@ -744,6 +741,3 @@ SDBMExpr stripe(SDBMExpr expr, SDBMExpr factor) { return SDBMStripeExpr::get(expr.cast(), constantFactor); } - -} // namespace ops_assertions -} // namespace mlir diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp index 1fd6274..be486f8 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp @@ -93,6 +93,7 @@ class DecorateSPIRVCompositeTypeLayoutPass private: void runOnModule() override; }; +} // namespace void DecorateSPIRVCompositeTypeLayoutPass::runOnModule() { auto module = getModule(); @@ -120,7 +121,6 @@ void DecorateSPIRVCompositeTypeLayoutPass::runOnModule() { } } } -} // namespace std::unique_ptr> mlir::spirv::createDecorateSPIRVCompositeTypeLayoutPass() { diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp index bbee80a..5098ba8 100644 --- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp @@ -63,14 +63,12 @@ using llvm::orc::RTDyldObjectLinkingLayer; using llvm::orc::ThreadSafeModule; using llvm::orc::TMOwningSimpleCompiler; -// Wrap a string into an llvm::StringError. -static inline Error make_string_error(const Twine &message) { +/// Wrap a string into an llvm::StringError. +static Error make_string_error(const Twine &message) { return llvm::make_error(message.str(), llvm::inconvertibleErrorCode()); } -namespace mlir { - void SimpleObjectCache::notifyObjectCompiled(const Module *M, MemoryBufferRef ObjBuffer) { cachedObjects[M->getModuleIdentifier()] = MemoryBuffer::getMemBufferCopy( @@ -316,4 +314,3 @@ Error ExecutionEngine::invoke(StringRef name, MutableArrayRef args) { return Error::success(); } -} // end namespace mlir diff --git a/mlir/lib/Quantizer/Support/Statistics.cpp b/mlir/lib/Quantizer/Support/Statistics.cpp index d155875..6753898 100644 --- a/mlir/lib/Quantizer/Support/Statistics.cpp +++ b/mlir/lib/Quantizer/Support/Statistics.cpp @@ -95,15 +95,10 @@ bool AttributeTensorStatistics::get(TensorAxisStatistics &stats) const { return false; } -namespace mlir { -namespace quantizer { - -raw_ostream &operator<<(raw_ostream &os, const TensorAxisStatistics &stats) { +raw_ostream &mlir::quantizer::operator<<(raw_ostream &os, + const TensorAxisStatistics &stats) { os << "STATS[sampleSize=" << stats.sampleSize << ", min=" << stats.minValue << ", maxValue=" << stats.maxValue << ", mean=" << stats.mean << ", variance=" << stats.variance << "]"; return os; } - -} // end namespace quantizer -} // end namespace mlir diff --git a/mlir/lib/Quantizer/Support/UniformSolvers.cpp b/mlir/lib/Quantizer/Support/UniformSolvers.cpp index bd2fe68..77d69be 100644 --- a/mlir/lib/Quantizer/Support/UniformSolvers.cpp +++ b/mlir/lib/Quantizer/Support/UniformSolvers.cpp @@ -127,16 +127,15 @@ double UniformParamsFromMinMaxSolver::dequantize(int64_t xq) const { return (xq - zp) * delta; } -namespace mlir { -namespace quantizer { - -raw_ostream &operator<<(raw_ostream &os, const UniformStorageParams &p) { +raw_ostream &mlir::quantizer::operator<<(raw_ostream &os, + const UniformStorageParams &p) { os << "UniformStorageParams{" << p.numLevels << ", " << p.minValue << "}"; return os; } -raw_ostream &operator<<(raw_ostream &os, - const UniformParamsFromMinMaxSolver &s) { +raw_ostream & +mlir::quantizer::operator<<(raw_ostream &os, + const UniformParamsFromMinMaxSolver &s) { os << "UniformParamsFromMinMaxSolver(" << s.getStepCount() << "){"; os << "(" << s.getBoundingMin() << ":" << s.getBoundingMax() << ") -> "; if (!s.isSatisfied()) { @@ -151,6 +150,3 @@ raw_ostream &operator<<(raw_ostream &os, return os; } - -} // end namespace quantizer -} // end namespace mlir diff --git a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp index 83c4869..8baed98 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp @@ -36,7 +36,6 @@ using namespace mlir; -namespace { static llvm::Value *createIntrinsicCall(llvm::IRBuilder<> &builder, llvm::Intrinsic::ID intrinsic, ArrayRef args = {}) { @@ -56,6 +55,7 @@ static llvm::Intrinsic::ID getShflBflyIntrinsicId(llvm::Type *resultType, : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32; } +namespace { class ModuleTranslation : public LLVM::ModuleTranslation { public: diff --git a/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp index c06e1ca..34786fb 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp @@ -39,7 +39,6 @@ using namespace mlir; -namespace { // Create a call to llvm intrinsic static llvm::Value *createIntrinsicCall(llvm::IRBuilder<> &builder, llvm::Intrinsic::ID intrinsic, @@ -67,6 +66,7 @@ static llvm::Value *createDeviceFunctionCall(llvm::IRBuilder<> &builder, return builder.CreateCall(fn, ArrayRef(fn_op0)); } +namespace { class ModuleTranslation : public LLVM::ModuleTranslation { public: diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 086c3a8..6206a88 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -36,13 +36,13 @@ #include "llvm/IR/Module.h" #include "llvm/Transforms/Utils/Cloning.h" -namespace mlir { -namespace LLVM { +using namespace mlir; +using namespace mlir::LLVM; -// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`. -// This currently supports integer, floating point, splat and dense element -// attributes and combinations thereof. In case of error, report it to `loc` -// and return nullptr. +/// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`. +/// This currently supports integer, floating point, splat and dense element +/// attributes and combinations thereof. In case of error, report it to `loc` +/// and return nullptr. llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType, Attribute attr, Location loc) { @@ -94,7 +94,7 @@ llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType, return nullptr; } -// Convert MLIR integer comparison predicate to LLVM IR comparison predicate. +/// Convert MLIR integer comparison predicate to LLVM IR comparison predicate. static llvm::CmpInst::Predicate getLLVMCmpPredicate(ICmpPredicate p) { switch (p) { case LLVM::ICmpPredicate::eq: @@ -159,10 +159,10 @@ static llvm::CmpInst::Predicate getLLVMCmpPredicate(FCmpPredicate p) { llvm_unreachable("incorrect comparison predicate"); } -// Given a single MLIR operation, create the corresponding LLVM IR operation -// using the `builder`. LLVM IR Builder does not have a generic interface so -// this has to be a long chain of `if`s calling different functions with a -// different number of arguments. +/// Given a single MLIR operation, create the corresponding LLVM IR operation +/// using the `builder`. LLVM IR Builder does not have a generic interface so +/// this has to be a long chain of `if`s calling different functions with a +/// different number of arguments. LogicalResult ModuleTranslation::convertOperation(Operation &opInst, llvm::IRBuilder<> &builder) { auto extractPosition = [](ArrayAttr attr) { @@ -232,9 +232,9 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst, << opInst.getName(); } -// Convert block to LLVM IR. Unless `ignoreArguments` is set, emit PHI nodes -// to define values corresponding to the MLIR block arguments. These nodes -// are not connected to the source basic blocks, which may not exist yet. +/// Convert block to LLVM IR. Unless `ignoreArguments` is set, emit PHI nodes +/// to define values corresponding to the MLIR block arguments. These nodes +/// are not connected to the source basic blocks, which may not exist yet. LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments) { llvm::IRBuilder<> builder(blockMapping[&bb]); @@ -268,7 +268,7 @@ LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments) { return success(); } -// Convert the LLVM dialect linkage type to LLVM IR linkage type. +/// Convert the LLVM dialect linkage type to LLVM IR linkage type. llvm::GlobalVariable::LinkageTypes convertLinkageType(LLVM::Linkage linkage) { switch (linkage) { case LLVM::Linkage::Private: @@ -297,8 +297,8 @@ llvm::GlobalVariable::LinkageTypes convertLinkageType(LLVM::Linkage linkage) { llvm_unreachable("unknown linkage type"); } -// Create named global variables that correspond to llvm.mlir.global -// definitions. +/// Create named global variables that correspond to llvm.mlir.global +/// definitions. void ModuleTranslation::convertGlobals() { for (auto op : getModuleBody(mlirModule).getOps()) { llvm::Type *type = op.getType().getUnderlyingType(); @@ -340,8 +340,8 @@ void ModuleTranslation::convertGlobals() { } } -// Get the SSA value passed to the current block from the terminator operation -// of its predecessor. +/// Get the SSA value passed to the current block from the terminator operation +/// of its predecessor. static Value *getPHISourceValue(Block *current, Block *pred, unsigned numArguments, unsigned index) { auto &terminator = *pred->getTerminator(); @@ -394,7 +394,7 @@ static void topologicalSortImpl(llvm::SetVector &blocks, Block *b) { } } -// Sort function blocks topologically. +/// Sort function blocks topologically. static llvm::SetVector topologicalSort(LLVMFuncOp f) { // For each blocks that has not been visited yet (i.e. that has no // predecessors), add it to the list and traverse its successors in DFS @@ -513,6 +513,3 @@ ModuleTranslation::prepareLLVMModule(Operation *m) { return llvmModule; } - -} // namespace LLVM -} // namespace mlir diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index 6627e73..5694c99 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -118,6 +118,14 @@ mlir::createLoopFusionPass(unsigned fastMemorySpace, maximalFusion); } +// TODO(b/117228571) Replace when this is modeled through side-effects/op traits +static bool isMemRefDereferencingOp(Operation &op) { + if (isa(op) || isa(op) || + isa(op) || isa(op)) + return true; + return false; +} + namespace { // LoopNestStateCollector walks loop nests and collects load and store @@ -142,14 +150,6 @@ struct LoopNestStateCollector { } }; -// TODO(b/117228571) Replace when this is modeled through side-effects/op traits -static bool isMemRefDereferencingOp(Operation &op) { - if (isa(op) || isa(op) || - isa(op) || isa(op)) - return true; - return false; -} - // MemRefDependenceGraph is a graph data structure where graph nodes are // top-level operations in a FuncOp which contain load/store ops, and edges // are memref dependences between the nodes. @@ -674,6 +674,8 @@ public: void dump() const { print(llvm::errs()); } }; +} // end anonymous namespace + // Initializes the data dependence graph by walking operations in 'f'. // Assigns each node in the graph a node id based on program order in 'f'. // TODO(andydavis) Add support for taking a Block arg to construct the @@ -872,7 +874,7 @@ static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) { } // TODO(mlir-team): improve/complete this when we have target data. -unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { +static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { auto elementType = memRefType.getElementType(); unsigned sizeInBits; @@ -1373,6 +1375,8 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, return true; } +namespace { + // GreedyFusion greedily fuses loop nests which have a producer/consumer or // input-reuse relationship on a memref, with the goal of improving locality. // -- 2.7.4