[mlir][linalg][bufferize][NFC] Remove remaining Comprehensive Bufferize code
authorMatthias Springer <springerm@google.com>
Tue, 3 May 2022 14:40:13 +0000 (23:40 +0900)
committerMatthias Springer <springerm@google.com>
Wed, 4 May 2022 08:19:44 +0000 (17:19 +0900)
This commit removes the Linalg Comprehensive Bufferize pass.

Differential Revision: https://reviews.llvm.org/D124854

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
mlir/include/mlir/Dialect/Linalg/Passes.h
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp [deleted file]
mlir/lib/Dialect/Linalg/Transforms/InitTensorElimination.cpp [new file with mode: 0644]
mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir [deleted file]
mlir/test/Dialect/Linalg/one-shot-bufferize-aliasing-in.mlir [moved from mlir/test/Dialect/Linalg/comprehensive-module-bufferize-aliasing-in.mlir with 95% similarity]
mlir/test/Dialect/Linalg/one-shot-bufferize-analysis-aliasing-in.mlir [moved from mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis-aliasing-in.mlir with 94% similarity]
mlir/test/Integration/Dialect/Linalg/CPU/test-one-shot-bufferize.mlir [moved from mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir with 98% similarity]

index 0c7a2b8..37d4a3e 100644 (file)
@@ -236,6 +236,10 @@ def OneShotBufferize : Pass<"one-shot-bufferize", "ModuleOp"> {
     Option<"allowUnknownOps", "allow-unknown-ops", "bool",
            /*default=*/"false",
            "Allows unknown (not bufferizable) ops in the input IR.">,
+    Option<"alwaysAliasingWithDest", "always-aliasing-with-dest", "bool",
+            /*default=*/"true",
+            "Tensor OpResult cannot bufferize inplace OpOperands other than "
+            "out/dest OpOperands (if the op has such operands; experimental)">,
     Option<"analysisFuzzerSeed", "analysis-fuzzer-seed", "unsigned",
            /*default=*/"0",
            "Test only: Analyze ops in random order with a given seed (fuzzer)">,
index 3510b2f..93d01a7 100644 (file)
@@ -62,17 +62,6 @@ createConvertLinalgToParallelLoopsPass();
 std::unique_ptr<OperationPass<func::FuncOp>>
 createConvertLinalgToAffineLoopsPass();
 
-/// This pass implements a cross-dialect bufferization approach and performs an
-/// analysis to determine which op operands and results may be bufferized in the
-/// same buffers. The analysis is performed on topologically sorted CallOp and
-/// FuncOp within a module. It provides analyses and bufferization across
-/// function boundaries. Within a function boundary, the analysis is performed
-/// on SSA use-def chains starting from function operands that are annotated
-/// with the 'inplaceable' attribute.
-std::unique_ptr<Pass> createLinalgComprehensiveModuleBufferizePass();
-std::unique_ptr<Pass> createLinalgComprehensiveModuleBufferizePass(
-    const bufferization::OneShotBufferizationOptions &options);
-
 /// Create a pass that tries to eliminate init_tensor ops that are anchored on
 /// insert_slice ops.
 std::unique_ptr<Pass> createLinalgInitTensorEliminationPass();
index 2c0287d..da48c65 100644 (file)
@@ -24,51 +24,6 @@ def ConvertElementwiseToLinalg : Pass<"convert-elementwise-to-linalg", ""> {
   let dependentDialects = ["linalg::LinalgDialect", "memref::MemRefDialect"];
 }
 
-def LinalgComprehensiveModuleBufferize :
-    Pass<"linalg-comprehensive-module-bufferize", "ModuleOp"> {
-  let summary = "Bufferize (tensor into memref) for a Module.";
-  let description = [{
-    This pass implements a cross-dialect bufferization approach and performs an
-    analysis to determine which op operands and results may be bufferized in the
-    same buffers. The analysis is performed on topologically sorted CallOp and
-    FuncOp within a module. It provides analyses and bufferization across
-    function boundaries. Within a function boundary, the analysis is performed
-    on SSA use-def chains starting from function operands that are annotated
-    with the 'inplaceable' attribute.
-  }];
-  let options = [
-    Option<"testAnalysisOnly", "test-analysis-only", "bool",
-            /*default=*/"false",
-           "Only runs inplaceability analysis (for testing purposes only)">,
-    Option<"printConflicts", "print-conflicts", "bool",
-            /*default=*/"false",
-           "Annotates IR with RaW conflicts. Requires test-analysis-only.">,
-    Option<"allowReturnAllocs", "allow-return-allocs", "bool",
-            /*default=*/"false",
-           "Allows returning/yielding new allocations from a block.">,
-    Option<"allowUnknownOps", "allow-unknown-ops", "bool",
-           /*default=*/"false",
-           "Allows unknown (not bufferizable) ops in the input IR.">,
-    Option<"alwaysAliasingWithDest", "always-aliasing-with-dest", "bool",
-            /*default=*/"true",
-            "Tensor OpResult cannot bufferize inplace OpOperands other than "
-            "out or dest OpOperands (if the op has a notion of such operands)">,
-    Option<"useAlloca", "use-alloca", "bool",
-           /*default=*/"false",
-           "Use stack allocations for memrefs (for testing purposes only)">,
-    Option<"fullyDynamicLayoutMaps", "fully-dynamic-layout-maps", "bool",
-           /*default=*/"true",
-           "Generate MemRef types with dynamic offset+strides by default.">,
-    Option<"analysisFuzzerSeed", "analysis-fuzzer-seed", "unsigned",
-           /*default=*/"0",
-           "Analyze ops in random order with a given seed (fuzzer)">,
-    Option<"createDeallocs", "create-deallocs", "bool", /*default=*/"true",
-           "Specify if buffers should be deallocated. For compatibility with "
-           "core bufferization passes.">,
-  ];
-  let constructor = "mlir::createLinalgComprehensiveModuleBufferizePass()";
-}
-
 def LinalgInitTensorElimination : Pass<"linalg-eliminate-init-tensors"> {
   let summary = "Try to eliminate all init_tensor ops.";
   let description = [{
index 29026eb..a16b191 100644 (file)
@@ -171,6 +171,7 @@ struct OneShotBufferizePass
       // pass.
       opt.allowReturnAllocs = allowReturnAllocs;
       opt.allowUnknownOps = allowUnknownOps;
+      opt.alwaysAliasingWithDest = alwaysAliasingWithDest;
       opt.analysisFuzzerSeed = analysisFuzzerSeed;
       opt.createDeallocs = createDeallocs;
       opt.fullyDynamicLayoutMaps = fullyDynamicLayoutMaps;
index 943efc5..2955391 100644 (file)
@@ -3,7 +3,6 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
   CodegenStrategy.cpp
-  ComprehensiveBufferizePass.cpp
   ConstantFold.cpp
   Detensorize.cpp
   DropUnitDims.cpp
@@ -14,6 +13,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   Generalization.cpp
   Hoisting.cpp
   HoistPadding.cpp
+  InitTensorElimination.cpp
   InlineScalarOperands.cpp
   Interchange.cpp
   Loops.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
deleted file mode 100644 (file)
index bbb013d..0000000
+++ /dev/null
@@ -1,161 +0,0 @@
-//===- ComprehensiveBufferize.cpp - Single pass bufferization -------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "PassDetail.h"
-
-#include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h"
-#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
-#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
-#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
-#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
-#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
-#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Linalg/Passes.h"
-#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
-#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h"
-#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
-#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassManager.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "mlir/Transforms/Passes.h"
-
-using namespace mlir;
-using namespace mlir::bufferization;
-using namespace mlir::linalg;
-
-namespace {
-struct LinalgComprehensiveModuleBufferize
-    : public LinalgComprehensiveModuleBufferizeBase<
-          LinalgComprehensiveModuleBufferize> {
-  LinalgComprehensiveModuleBufferize() = default;
-
-  LinalgComprehensiveModuleBufferize(
-      const LinalgComprehensiveModuleBufferize &p) = default;
-
-  explicit LinalgComprehensiveModuleBufferize(
-      const OneShotBufferizationOptions &options)
-      : options(options) {}
-
-  void runOnOperation() override;
-
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry
-        .insert<bufferization::BufferizationDialect, linalg::LinalgDialect,
-                memref::MemRefDialect, tensor::TensorDialect,
-                vector::VectorDialect, scf::SCFDialect,
-                arith::ArithmeticDialect, func::FuncDialect, AffineDialect>();
-    arith::registerBufferizableOpInterfaceExternalModels(registry);
-    bufferization::registerAllocationOpInterfaceExternalModels(registry);
-    linalg::registerBufferizableOpInterfaceExternalModels(registry);
-    scf::registerBufferizableOpInterfaceExternalModels(registry);
-    func_ext::registerBufferizableOpInterfaceExternalModels(registry);
-    tensor::registerBufferizableOpInterfaceExternalModels(registry);
-    vector::registerBufferizableOpInterfaceExternalModels(registry);
-  }
-
-private:
-  llvm::Optional<OneShotBufferizationOptions> options;
-};
-
-struct LinalgInitTensorElimination
-    : public LinalgInitTensorEliminationBase<LinalgInitTensorElimination> {
-  LinalgInitTensorElimination() = default;
-
-  void runOnOperation() override;
-
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<linalg::LinalgDialect, tensor::TensorDialect>();
-  }
-};
-} // namespace
-
-static void applyEnablingTransformations(ModuleOp moduleOp) {
-  RewritePatternSet patterns(moduleOp.getContext());
-  patterns.add<GeneralizePadOpPattern>(moduleOp.getContext());
-  (void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns));
-}
-
-static FailureOr<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
-                                                MemRefType type,
-                                                ValueRange dynShape,
-                                                unsigned int bufferAlignment) {
-  Value allocated = b.create<memref::AllocaOp>(
-      loc, type, dynShape, b.getI64IntegerAttr(bufferAlignment));
-  return allocated;
-}
-
-void LinalgComprehensiveModuleBufferize::runOnOperation() {
-  OneShotBufferizationOptions opt;
-  if (!options) {
-    // Make new bufferization options if none were provided when creating the
-    // pass.
-    if (useAlloca) {
-      opt.allocationFn = allocationFnUsingAlloca;
-      opt.deallocationFn = [](OpBuilder &b, Location loc, Value v) {
-        return success();
-      };
-    }
-    opt.allowReturnAllocs = allowReturnAllocs;
-    opt.allowUnknownOps = allowUnknownOps;
-    opt.analysisFuzzerSeed = analysisFuzzerSeed;
-    opt.createDeallocs = createDeallocs;
-    opt.fullyDynamicLayoutMaps = fullyDynamicLayoutMaps;
-    opt.printConflicts = printConflicts;
-    opt.testAnalysisOnly = testAnalysisOnly;
-    opt.alwaysAliasingWithDest = alwaysAliasingWithDest;
-    opt.bufferizeFunctionBoundaries = true;
-  } else {
-    opt = *options;
-  }
-
-  ModuleOp moduleOp = getOperation();
-  applyEnablingTransformations(moduleOp);
-
-  if (failed(runOneShotModuleBufferize(moduleOp, opt))) {
-    signalPassFailure();
-    return;
-  }
-
-  if (opt.testAnalysisOnly)
-    return;
-
-  OpPassManager cleanupPipeline("builtin.module");
-  cleanupPipeline.addPass(createCanonicalizerPass());
-  cleanupPipeline.addPass(createCSEPass());
-  cleanupPipeline.addPass(createLoopInvariantCodeMotionPass());
-  (void)runPipeline(cleanupPipeline, moduleOp);
-}
-
-void LinalgInitTensorElimination::runOnOperation() {
-  Operation *op = getOperation();
-  OneShotBufferizationOptions options;
-  OneShotAnalysisState state(op, options);
-  if (failed(analyzeOp(op, state))) {
-    signalPassFailure();
-    return;
-  }
-
-  IRRewriter rewriter(op->getContext());
-  if (failed(insertSliceAnchoredInitTensorEliminationStep(rewriter, op, state)))
-    signalPassFailure();
-}
-
-std::unique_ptr<Pass> mlir::createLinalgComprehensiveModuleBufferizePass() {
-  return std::make_unique<LinalgComprehensiveModuleBufferize>();
-}
-
-std::unique_ptr<Pass> mlir::createLinalgComprehensiveModuleBufferizePass(
-    const OneShotBufferizationOptions &options) {
-  return std::make_unique<LinalgComprehensiveModuleBufferize>(options);
-}
-
-std::unique_ptr<Pass> mlir::createLinalgInitTensorEliminationPass() {
-  return std::make_unique<LinalgInitTensorElimination>();
-}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/InitTensorElimination.cpp b/mlir/lib/Dialect/Linalg/Transforms/InitTensorElimination.cpp
new file mode 100644 (file)
index 0000000..f48f9c8
--- /dev/null
@@ -0,0 +1,50 @@
+//===- ComprehensiveBufferize.cpp - Single pass bufferization -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace mlir::bufferization;
+using namespace mlir::linalg;
+
+namespace {
+struct LinalgInitTensorElimination
+    : public LinalgInitTensorEliminationBase<LinalgInitTensorElimination> {
+  LinalgInitTensorElimination() = default;
+
+  void runOnOperation() override;
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<linalg::LinalgDialect, tensor::TensorDialect>();
+  }
+};
+} // namespace
+
+void LinalgInitTensorElimination::runOnOperation() {
+  Operation *op = getOperation();
+  OneShotBufferizationOptions options;
+  OneShotAnalysisState state(op, options);
+  if (failed(analyzeOp(op, state))) {
+    signalPassFailure();
+    return;
+  }
+
+  IRRewriter rewriter(op->getContext());
+  if (failed(insertSliceAnchoredInitTensorEliminationStep(rewriter, op, state)))
+    signalPassFailure();
+}
+
+std::unique_ptr<Pass> mlir::createLinalgInitTensorEliminationPass() {
+  return std::make_unique<LinalgInitTensorElimination>();
+}
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir
deleted file mode 100644 (file)
index 88613a2..0000000
+++ /dev/null
@@ -1,65 +0,0 @@
-// RUN: mlir-opt %s -pass-pipeline="linalg-comprehensive-module-bufferize{allow-return-allocs use-alloca}" -split-input-file | FileCheck %s
-
-//  CHECK-DAG: #[[$DYN_0D_MAP:.*]] = affine_map<()[s0] -> (s0)>
-//  CHECK-DAG: #[[$DYN_1D_MAP:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
-
-//      CHECK:  func @init_and_dot(
-// CHECK-SAME:    %[[A:[a-zA-Z0-9]*]]: memref<64xf32, #[[$DYN_1D_MAP]]>
-// CHECK-SAME:    %[[B:[a-zA-Z0-9]*]]: memref<64xf32, #[[$DYN_1D_MAP]]>
-// CHECK-SAME:    %[[C:[a-zA-Z0-9]*]]: memref<f32, #[[$DYN_0D_MAP]]>
-func.func @init_and_dot(%a: tensor<64xf32>, %b: tensor<64xf32>, %c: tensor<f32>) -> tensor<f32> {
-  // CHECK-NEXT:   %[[C0:.*]] = arith.constant 0{{.*}} : f32
-  %v0 = arith.constant 0.0 : f32
-
-  // CHECK-NEXT:   linalg.fill ins(%[[C0]] : f32) outs(%[[C]] : memref<f32, #[[$DYN_0D_MAP]]>)
-  %d = linalg.fill ins(%v0 : f32) outs(%c : tensor<f32>) -> tensor<f32>
-
-  // CHECK-NEXT:   linalg.dot ins(%[[A]], %[[B]] : memref<64xf32, #[[$DYN_1D_MAP]]>, memref<64xf32, #[[$DYN_1D_MAP]]>) outs(%[[C]] : memref<f32, #[[$DYN_0D_MAP]]>)
-  %e = linalg.dot ins(%a, %b : tensor<64xf32>,tensor<64xf32>)
-    outs(%d: tensor<f32>) -> tensor<f32>
-
-  // CHECK-NEXT:   return
-  return %e : tensor<f32>
-}
-
-//      CHECK:  func @main()
-func.func @main() {
-  //  CHECK-DAG:   %[[C0:.*]] = arith.constant 0{{.*}} : f32
-  //  CHECK-DAG:   %[[C1:.*]] = arith.constant 1{{.*}} : f32
-  //  CHECK-DAG:   %[[C2:.*]] = arith.constant 2{{.*}} : f32
-  %v0 = arith.constant 0.0 : f32
-  %v1 = arith.constant 1.0 : f32
-  %v2 = arith.constant 2.0 : f32
-
-  // CHECK-NEXT:   %[[A:.*]] = memref.alloca() {alignment = 128 : i64} : memref<64xf32>
-  // CHECK-NEXT:   %[[B:.*]] = memref.alloca() {alignment = 128 : i64} : memref<64xf32>
-  // CHECK-NEXT:   %[[C:.*]] = memref.alloca() {alignment = 128 : i64} : memref<f32>
-  //  CHECK-DAG:   %[[cA:.*]] = memref.cast %[[A]] : memref<64xf32> to memref<64xf32, #[[$DYN_1D_MAP]]>
-  //  CHECK-DAG:   %[[cB:.*]] = memref.cast %[[B]] : memref<64xf32> to memref<64xf32, #[[$DYN_1D_MAP]]>
-  //  CHECK-DAG:   %[[cC:.*]] = memref.cast %[[C]] : memref<f32> to memref<f32, #[[$DYN_0D_MAP]]>
-  %A = linalg.init_tensor [64] : tensor<64xf32>
-  %B = linalg.init_tensor [64] : tensor<64xf32>
-  %C = linalg.init_tensor [] : tensor<f32>
-
-  //  CHECK-DAG:   linalg.fill ins(%[[C1]] : f32) outs(%[[A]] : memref<64xf32>)
-  //  CHECK-DAG:   linalg.fill ins(%[[C2]] : f32) outs(%[[B]] : memref<64xf32>)
-  //  CHECK-DAG:   linalg.fill ins(%[[C0]] : f32) outs(%[[C]] : memref<f32>)
-  %AA = linalg.fill ins(%v1 : f32) outs(%A : tensor<64xf32>) -> tensor<64xf32>
-  %BB = linalg.fill ins(%v2 : f32) outs(%B : tensor<64xf32>) -> tensor<64xf32>
-  %CC = linalg.fill ins(%v0 : f32) outs(%C : tensor<f32>) -> tensor<f32>
-
-  // CHECK-NEXT:   call @init_and_dot(%[[cA]], %[[cB]], %[[cC]])
-  %res = call @init_and_dot(%AA, %BB, %CC) :
-    (tensor<64xf32>, tensor<64xf32>, tensor<f32>) -> tensor<f32>
-
-  // CHECK-NEXT:   %[[dC:.*]] = memref.cast %[[C]] : memref<f32> to memref<*xf32>
-  %res2 = tensor.cast %res: tensor<f32> to tensor<*xf32>
-
-  // CHECK-NEXT:   call @print_memref_f32(%[[dC]]) : (memref<*xf32>) -> ()
-  call @print_memref_f32(%res2) : (tensor<*xf32>) -> ()
-
-  return
-}
-
-//     CHECK:   func private @print_memref_f32(memref<*xf32>)
-func.func private @print_memref_f32(tensor<*xf32>)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="allow-return-allocs always-aliasing-with-dest=0" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries allow-return-allocs always-aliasing-with-dest=0" -split-input-file | FileCheck %s
 
 // CHECK-LABEL: func @linalg_op_bufferizes_inplace_with_input
 //  CHECK-SAME:     %[[t1:.*]]: memref<?x?xf32, #{{.*}}>, %[[t2:.*]]: memref<?xf32, #{{.*}}>, %[[t3:.*]]: memref<?x?xf32, #{{.*}}>
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only allow-return-allocs always-aliasing-with-dest=0" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries test-analysis-only allow-return-allocs always-aliasing-with-dest=0" -split-input-file | FileCheck %s
 
 // This is a test case for alwaysAliasingWithDest = 0. In that case, an OpResult
 // may bufferize in-place with an "in" OpOperand or any non-"out" OpOperand.
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -pass-pipeline="func.func(canonicalize,cse),linalg-comprehensive-module-bufferize" |\
+// RUN: mlir-opt %s -pass-pipeline="func.func(canonicalize,cse),one-shot-bufferize{bufferize-function-boundaries}" |\
 // RUN: mlir-opt -pass-pipeline="func.func(buffer-deallocation,convert-vector-to-scf,lower-affine,convert-linalg-to-loops)" |\
 // RUN: mlir-opt -pass-pipeline="func.func(canonicalize,convert-scf-to-cf),convert-vector-to-llvm,convert-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts" | \
 
@@ -22,7 +22,7 @@ func.func @init_and_dot(%arg0: tensor<64xf32>, %arg1: tensor<64xf32>, %arg2: ten
     %9 = tensor.extract_slice %arg1[%arg3] [2] [1] : tensor<64xf32> to tensor<2xf32>
     %10 = tensor.cast %9 : tensor<2xf32> to tensor<?xf32>
     %11 = tensor.pad %10 low[%c0] high[%c0]  {
-    ^bb0(%arg5: index):  
+    ^bb0(%arg5: index):
       tensor.yield %cst : f32
     } : tensor<?xf32> to tensor<2xf32>
     %12 = tensor.insert_slice %11 into %arg4[%8, 0] [1, 2] [1, 1] : tensor<2xf32> into tensor<?x2xf32>