[mlir][tosa] Switch TosaFoldConstantTranspose to use ElementsAttr.
authorJacques Pienaar <jpienaar@google.com>
Mon, 22 Aug 2022 22:45:23 +0000 (15:45 -0700)
committerJacques Pienaar <jpienaar@google.com>
Mon, 22 Aug 2022 22:45:23 +0000 (15:45 -0700)
Also avoid redoing index calculation.

Differential Revision: https://reviews.llvm.org/D132274

mlir/lib/Dialect/Tosa/Transforms/TosaFoldConstantTranspose.cpp

index f86605d..3147073 100644 (file)
@@ -30,7 +30,7 @@ struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
     if (!outputType.getElementType().isIntOrIndexOrFloat())
       return failure();
 
-    DenseElementsAttr inputValues;
+    ElementsAttr inputValues;
     if (!matchPattern(op.getInput1(), m_Constant(&inputValues)))
       return failure();
     // Make sure the input is a constant that has a single user.
@@ -57,10 +57,9 @@ struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
     // index.
     auto attrValues = inputValues.getValues<Attribute>();
     ArrayRef<int64_t> outputShape = outputType.getShape();
-    for (int srcLinearIndex = 0; srcLinearIndex < numElements;
-         ++srcLinearIndex) {
+    for (const auto &it : llvm::enumerate(attrValues)) {
       SmallVector<uint64_t, 6> srcIndices(inputType.getRank(), 0);
-      int totalCount = srcLinearIndex;
+      int totalCount = it.index();
       for (int dim = inputType.getRank() - 1; dim >= 0; --dim) {
         srcIndices[dim] = totalCount % inputShape[dim];
         totalCount /= inputShape[dim];
@@ -74,7 +73,7 @@ struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
       for (int dim = 1; dim < outputType.getRank(); ++dim)
         dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];
 
-      outputValues[dstLinearIndex] = attrValues[srcIndices];
+      outputValues[dstLinearIndex] = it.value();
     }
 
     rewriter.replaceOpWithNewOp<tosa::ConstOp>(