From c788cad83b6b5c24f8160f9fc11a69dd7beafb8b Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Thu, 28 Oct 2021 09:45:07 -0400 Subject: [PATCH] [mlir][linalg] Fix FoldConstantTranspose execution inefficiency * Move SmallVectors outside of inner loops to avoid frequent allocations and deallocations * Calculate linearized index and call flat range getters to avoid internal shape querying behind `getValue`. Reviewed By: mravishankar Differential Revision: https://reviews.llvm.org/D112099 --- .../Linalg/Transforms/ElementwiseOpFusion.cpp | 126 +++++++++++++++------ 1 file changed, 93 insertions(+), 33 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 32ad335..ee5622d2 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1286,12 +1286,16 @@ private: template class FoldConstantBase : public OpRewritePattern { public: + struct APIntOrFloat { + Optional apInt; + Optional apFloat; + }; struct APIntOrFloatArray { SmallVector apInts; SmallVector apFloats; }; using RegionComputationFn = - std::function; + std::function; FoldConstantBase(MLIRContext *context, const ControlElementwiseOpsFusionFn &controlFn, @@ -1403,57 +1407,109 @@ public: auto outputDims = getDimPositions(genericOp.getIndexingMaps().back()); auto outputShape = outputType.getShape(); - // Transpose the input constant. Because we don't know its rank in advance, - // we need to loop over the range [0, element count) and delinearize the - // index. - for (int linearIndex0 = 0; linearIndex0 < numElements; ++linearIndex0) { - SmallVector indices(loopBounds.size(), 0); - int totalCount = linearIndex0; + // Allocate small vectors for index delinearization. Initial values do not + // matter here as they will be overwritten later. 
+ SmallVector indices(loopBounds.size(), 0); + SmallVector dstIndices(loopBounds.size(), 0); + SmallVector> srcIndices( + numInputs, SmallVector(loopBounds.size(), 0)); + SmallVector srcLinearIndices(numInputs, 0); + uint64_t dstLinearIndex = 0; + + // Allocate space for compute function inputs. Initial values do not matter + // here as they will be overwritten later. + APIntOrFloatArray computeFnInputs; + + auto inputShapes = llvm::to_vector<4>( + llvm::map_range(genericOp.getInputOperands(), [](OpOperand *operand) { + return operand->get().getType().cast().getShape(); + })); + + // Given a `linearIndex`, remap it to a linear index to access linalg op + // inputs/outputs. This mutates `indices`, `srcIndices`, `dstIndices`, + // `srcLinearIndices`, `dstLinearIndex` in place. + auto computeRemappedLinearIndex = [&](int linearIndex) { + int totalCount = linearIndex; for (int dim = loopBounds.size() - 1; dim >= 0; --dim) { indices[dim] = totalCount % loopBounds[dim]; totalCount /= loopBounds[dim]; } - SmallVector> srcIndices; - for (int i = 0; i < numInputs; ++i) - srcIndices.emplace_back(loopBounds.size(), 0); - SmallVector dstIndices(loopBounds.size(), 0); - for (int dim = loopBounds.size() - 1; dim >= 0; --dim) { for (int i = 0; i < numInputs; ++i) srcIndices[i][dim] = indices[inputDims[i][dim]]; dstIndices[dim] = indices[outputDims[dim]]; } - uint64_t linearIndex1 = dstIndices.front(); - for (int dim = 1; dim < outputType.getRank(); ++dim) - linearIndex1 = linearIndex1 * outputShape[dim] + dstIndices[dim]; + dstLinearIndex = dstIndices.front(); + for (int i = 0; i < numInputs; ++i) + srcLinearIndices[i] = srcIndices[i].front(); - // Collect constant elements for all inputs at this loop iteration. 
- SmallVector intValues; - SmallVector fpValues; - if (elementType.isa()) { + for (int dim = 1; dim < outputType.getRank(); ++dim) { + dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim]; for (int i = 0; i < numInputs; ++i) - fpValues.push_back(inputValues[i].getValue(srcIndices[i])); - } else { - for (int i = 0; i < numInputs; ++i) - intValues.push_back(inputValues[i].getValue(srcIndices[i])); + srcLinearIndices[i] = + srcLinearIndices[i] * inputShapes[i][dim] + srcIndices[i][dim]; + } + }; + + bool isFloat = elementType.isa(); + if (isFloat) { + SmallVector> + inputFpIterators; + for (int i = 0; i < numInputs; ++i) + inputFpIterators.push_back(inputValues[i].getValues()); + + computeFnInputs.apFloats.resize(numInputs, APFloat(0.f)); + + // Transpose the input constant. Because we don't know its rank in + // advance, we need to loop over the range [0, element count) and + // delinearize the index. + for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) { + computeRemappedLinearIndex(linearIndex); + + // Collect constant elements for all inputs at this loop iteration. + for (int i = 0; i < numInputs; ++i) { + computeFnInputs.apFloats[i] = + *(inputFpIterators[i].begin() + srcLinearIndices[i]); + } + + // Invoke the computation to get the corresponding constant output + // element. + APIntOrFloat outputs = computeFn(computeFnInputs); + + fpOutputValues[dstLinearIndex] = outputs.apFloat.getValue(); } + } else { + SmallVector> + inputIntIterators; + for (int i = 0; i < numInputs; ++i) + inputIntIterators.push_back(inputValues[i].getValues()); + + computeFnInputs.apInts.resize(numInputs); + + // Transpose the input constant. Because we don't know its rank in + // advance, we need to loop over the range [0, element count) and + // delinearize the index. 
+ for (int linearIndex = 0; linearIndex < numElements; ++linearIndex) { + computeRemappedLinearIndex(linearIndex); - // Invoke the computation to get the corresponding constant output - // element. - APIntOrFloatArray inputs = {intValues, fpValues}; - APIntOrFloatArray outputs = computeFn(inputs); + // Collect constant elements for all inputs at this loop iteration. + for (int i = 0; i < numInputs; ++i) { + computeFnInputs.apInts[i] = + *(inputIntIterators[i].begin() + srcLinearIndices[i]); + } + + // Invoke the computation to get the corresponding constant output + // element. + APIntOrFloat outputs = computeFn(computeFnInputs); - if (elementType.isa()) { - fpOutputValues[linearIndex1] = outputs.apFloats.front(); - } else { - intOutputValues[linearIndex1] = outputs.apInts.front(); + intOutputValues[dstLinearIndex] = outputs.apInt.getValue(); } } DenseIntOrFPElementsAttr outputAttr; - if (elementType.isa()) { + if (isFloat) { outputAttr = DenseFPElementsAttr::get(outputType, fpOutputValues); } else { outputAttr = DenseIntElementsAttr::get(outputType, intOutputValues); @@ -1494,7 +1550,11 @@ struct FoldConstantTranspose : public FoldConstantBase { } // No computation; just return the orginal value. - return [](APIntOrFloatArray inputs) { return inputs; }; + return [](const APIntOrFloatArray &inputs) { + if (inputs.apFloats.empty()) + return APIntOrFloat{inputs.apInts.front(), llvm::None}; + return APIntOrFloat{llvm::None, inputs.apFloats.front()}; + }; } ControlElementwiseOpsFusionFn controlFn; -- 2.7.4