[mlir][tosa] Fold tosa.reshape with splat values
authorRob Suderman <suderman@google.com>
Tue, 30 Aug 2022 00:05:23 +0000 (17:05 -0700)
committerRob Suderman <suderman@google.com>
Tue, 30 Aug 2022 00:18:03 +0000 (17:18 -0700)
Folding reshapes of splats is trivial and should be canonicalized
away.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D132760

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
mlir/test/Dialect/Tosa/constant-op-fold.mlir

index 6af5b3d..6d27d1e 100644 (file)
@@ -717,9 +717,18 @@ OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) {
   auto inputTy = getInput1().getType().dyn_cast<RankedTensorType>();
   auto outputTy = getType().dyn_cast<RankedTensorType>();
 
-  if (!inputTy || !outputTy || inputTy != outputTy)
+  if (!inputTy || !outputTy)
     return {};
-  return getInput1();
+
+  if (inputTy == outputTy)
+    return getInput1();
+
+  auto operand = operands[0].dyn_cast_or_null<DenseElementsAttr>();
+  if (operand && outputTy.hasStaticShape() && operand.isSplat()) {
+    return SplatElementsAttr::get(outputTy, operand.getSplatValue<Attribute>());
+  }
+
+  return {};
 }
 
 OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
index f5dad46..af24d59 100644 (file)
@@ -398,6 +398,16 @@ func.func @fold_greater_splat_i32_true() -> tensor<10xi1> {
 
 // -----
 
+func.func @reshape_splat() -> tensor<6x5x4xi32> {
+  // CHECK: %[[SPLAT:.+]] = "tosa.const"() {value = dense<42> : tensor<6x5x4xi32>}
+  %splat = "tosa.const"() {value = dense<42> : tensor<4x5x6xi32>} : () -> tensor<4x5x6xi32>
+  %reshape = "tosa.reshape"(%splat) { new_shape = [6, 5, 4] } : (tensor<4x5x6xi32>) -> tensor<6x5x4xi32>
+  // CHECK: return %[[SPLAT]]
+  return %reshape : tensor<6x5x4xi32>
+}
+
+// -----
+
 // CHECK-LABEL: @slice_splat
 func.func @slice_splat() -> tensor<1x1x1xi32> {
   // CHECK: %[[SLICE:.+]] = "tosa.const"() {value = dense<42> : tensor<1x1x1xi32>}