[MLIR][TOSA] Resubmit Tosa to Standard/SCF Lowerings (const, if, while)"
authorRob Suderman <rob.suderman@gmail.com>
Fri, 26 Feb 2021 02:08:29 +0000 (18:08 -0800)
committerRob Suderman <rob.suderman@gmail.com>
Sat, 27 Feb 2021 01:44:12 +0000 (17:44 -0800)
Includes a lowering for tosa.const, tosa.if, and tosa.while to Standard/SCF dialects. TosaToStandard is
used for constant lowerings and TosaToSCF handles the if/while ops.

Resubmission of https://reviews.llvm.org/D97518 with ASAN fixes.

Differential Revision: https://reviews.llvm.org/D97529

14 files changed:
mlir/include/mlir/Conversion/Passes.h
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Conversion/TosaToSCF/TosaToSCF.h [new file with mode: 0644]
mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h [new file with mode: 0644]
mlir/lib/Conversion/CMakeLists.txt
mlir/lib/Conversion/PassDetail.h
mlir/lib/Conversion/TosaToSCF/CMakeLists.txt [new file with mode: 0644]
mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp [new file with mode: 0644]
mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp [new file with mode: 0644]
mlir/lib/Conversion/TosaToStandard/CMakeLists.txt [new file with mode: 0644]
mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp [new file with mode: 0644]
mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp [new file with mode: 0644]
mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir [new file with mode: 0644]
mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir [new file with mode: 0644]

index 121dae6..21e604e 100644 (file)
@@ -31,6 +31,8 @@
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h"
 #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
+#include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
+#include "mlir/Conversion/TosaToStandard/TosaToStandard.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h"
 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
index aa22878..f372838 100644 (file)
@@ -441,6 +441,36 @@ def TosaToLinalgOnTensors : FunctionPass<"tosa-to-linalg-on-tensors"> {
 }
 
 //===----------------------------------------------------------------------===//
+// TosaToSCF
+//===----------------------------------------------------------------------===//
+
+def TosaToSCF : Pass<"tosa-to-scf"> {
+  let summary = "Lower TOSA to the SCF dialect";
+  let dependentDialects = ["tensor::TensorDialect, scf::SCFDialect"];
+  let description = [{
+    Pass that converts TOSA's control flow operations to the equivalent SCF
+    operations.
+  }];
+
+  let constructor = "tosa::createTosaToSCF()";
+}
+
+//===----------------------------------------------------------------------===//
+// TosaToStandard
+//===----------------------------------------------------------------------===//
+
+def TosaToStandard : Pass<"tosa-to-standard"> {
+  let summary = "Lower TOSA to the Standard dialect";
+  let dependentDialects = ["StandardOpsDialect"];
+  let description = [{
+    Pass that converts TOSA operations to the equivalent operations using the
+    operations in the Standard dialect.
+  }];
+
+  let constructor = "tosa::createTosaToStandard()";
+}
+
+//===----------------------------------------------------------------------===//
 // VectorToSCF
 //===----------------------------------------------------------------------===//
 
diff --git a/mlir/include/mlir/Conversion/TosaToSCF/TosaToSCF.h b/mlir/include/mlir/Conversion/TosaToSCF/TosaToSCF.h
new file mode 100644 (file)
index 0000000..68ed0e0
--- /dev/null
@@ -0,0 +1,32 @@
+//===-- TosaToSCF.h - TOSA to SCF dialect lowerings -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the passes for the TOSA to SCF Dialect conversion.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_TOSATOSCF_TOSATOSCF_H
+#define MLIR_CONVERSION_TOSATOSCF_TOSATOSCF_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace tosa {
+
+std::unique_ptr<Pass> createTosaToSCF();
+
+void populateTosaToSCFConversionPatterns(MLIRContext *context,
+                                         OwningRewritePatternList *patterns);
+
+/// Populates passes to convert from TOSA to SCF.
+void addTosaToSCFPasses(OpPassManager &pm);
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_TOSATOSCF_TOSATOSCF_H
diff --git a/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h b/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h
new file mode 100644 (file)
index 0000000..8255500
--- /dev/null
@@ -0,0 +1,32 @@
+//===-- TosaToStandard.h - TOSA optimization pass declarations --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the passes for the TOSA to Standard Dialect conversion.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_TOSATOSTANDARD_TOSATOSTANDARD_H
+#define MLIR_CONVERSION_TOSATOSTANDARD_TOSATOSTANDARD_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace tosa {
+
+std::unique_ptr<Pass> createTosaToStandard();
+
+void populateTosaToStandardConversionPatterns(
+    MLIRContext *context, OwningRewritePatternList *patterns);
+
+/// Populates passes to convert from TOSA to Standard.
+void addTosaToStandardPasses(OpPassManager &pm);
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_TOSATOSTANDARD_TOSATOSTANDARD_H
index 6ba8d41..2f80084 100644 (file)
@@ -22,6 +22,8 @@ add_subdirectory(SPIRVToLLVM)
 add_subdirectory(StandardToLLVM)
 add_subdirectory(StandardToSPIRV)
 add_subdirectory(TosaToLinalg)
+add_subdirectory(TosaToSCF)
+add_subdirectory(TosaToStandard)
 add_subdirectory(ArmSVEToLLVM)
 add_subdirectory(VectorToROCDL)
 add_subdirectory(VectorToLLVM)
index c0e1791..7c1db73 100644 (file)
@@ -59,6 +59,14 @@ namespace spirv {
 class SPIRVDialect;
 } // end namespace spirv
 
+namespace tensor {
+class TensorDialect;
+} // end namespace tensor
+
+namespace tosa {
+class TosaDialect;
+} // end namespace tosa
+
 namespace vector {
 class VectorDialect;
 } // end namespace vector
diff --git a/mlir/lib/Conversion/TosaToSCF/CMakeLists.txt b/mlir/lib/Conversion/TosaToSCF/CMakeLists.txt
new file mode 100644 (file)
index 0000000..189c25c
--- /dev/null
@@ -0,0 +1,21 @@
+add_mlir_conversion_library(MLIRTosaToSCF
+  TosaToSCF.cpp
+  TosaToSCFPass.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRSCF
+  MLIRStandard
+  MLIRPass
+  MLIRTensor
+  MLIRTosa
+  MLIRTosaTransforms
+  MLIRSupport
+  )
diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
new file mode 100644 (file)
index 0000000..55ed64b
--- /dev/null
@@ -0,0 +1,109 @@
+//===- TosaToSCF.cpp - Lowering Tosa to SCF Dialect -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// These rewriters lower from the Tosa to the SCF dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace tosa;
+
+static void inlineIfCase(Region &srcRegion, Region &dstRegion,
+                         OperandRange operands, PatternRewriter &rewriter) {
+  rewriter.cloneRegionBefore(srcRegion, &dstRegion.front());
+  rewriter.eraseBlock(&dstRegion.back());
+
+  Block *headBlock = &dstRegion.front();
+  for (auto it : llvm::zip(headBlock->getArguments(), operands))
+    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
+
+  auto yield = cast<YieldOp>(headBlock->getTerminator());
+  rewriter.setInsertionPoint(yield);
+  rewriter.create<scf::YieldOp>(yield.getLoc(), yield.inputs());
+  rewriter.eraseOp(yield);
+
+  headBlock->eraseArguments(
+      llvm::to_vector<4>(llvm::seq<unsigned>(0, headBlock->getNumArguments())));
+}
+
+static void inlineWhileCase(Region &srcRegion, Region &dstRegion,
+                            PatternRewriter &rewriter, bool isCond) {
+  rewriter.cloneRegionBefore(srcRegion, &dstRegion.back());
+  rewriter.eraseBlock(&dstRegion.back());
+
+  Block *headBlock = &dstRegion.front();
+
+  auto yield = cast<YieldOp>(headBlock->getTerminator());
+  rewriter.setInsertionPoint(yield);
+  if (isCond) {
+    auto condition =
+        rewriter.create<tensor::ExtractOp>(yield.getLoc(), yield.getOperand(0));
+    rewriter.create<scf::ConditionOp>(yield.getLoc(), condition,
+                                      headBlock->getArguments());
+  } else {
+    rewriter.setInsertionPoint(yield);
+    rewriter.create<scf::YieldOp>(yield.getLoc(), yield.inputs());
+  }
+  rewriter.eraseOp(yield);
+}
+
+namespace {
+
+class IfOpConverter : public OpRewritePattern<tosa::IfOp> {
+public:
+  using OpRewritePattern<tosa::IfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::IfOp op,
+                                PatternRewriter &rewriter) const final {
+    auto condition = rewriter.create<tensor::ExtractOp>(op.getLoc(), op.cond());
+    auto newIf = rewriter.create<scf::IfOp>(op.getLoc(), op.getResultTypes(),
+                                            condition, true);
+
+    inlineIfCase(op.then_branch(), newIf.thenRegion(), op.inputs(), rewriter);
+    inlineIfCase(op.else_branch(), newIf.elseRegion(), op.inputs(), rewriter);
+
+    rewriter.replaceOp(op, newIf.getResults());
+    return success();
+  }
+};
+
+class WhileOpConverter : public OpRewritePattern<tosa::WhileOp> {
+public:
+  using OpRewritePattern<tosa::WhileOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::WhileOp op,
+                                PatternRewriter &rewriter) const final {
+    auto newWhile = rewriter.create<scf::WhileOp>(
+        op.getLoc(), op.getResultTypes(), op.inputs());
+    rewriter.createBlock(&newWhile.before());
+    rewriter.createBlock(&newWhile.after());
+
+    inlineWhileCase(op.cond(), newWhile.before(), rewriter, true);
+    inlineWhileCase(op.body(), newWhile.after(), rewriter, false);
+
+    rewriter.replaceOp(op, newWhile.getResults());
+
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::tosa::populateTosaToSCFConversionPatterns(
+    MLIRContext *context, OwningRewritePatternList *patterns) {
+  patterns->insert<IfOpConverter>(context);
+  patterns->insert<WhileOpConverter>(context);
+}
diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp
new file mode 100644 (file)
index 0000000..f403a46
--- /dev/null
@@ -0,0 +1,53 @@
+//===- TosaToSCFPass.cpp - Lowering Tosa to SCF Dialect -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transformation pass legalizes Tosa operations to the SCF dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "../PassDetail.h"
+#include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace tosa;
+
+namespace {
+struct TosaToSCF : public TosaToSCFBase<TosaToSCF> {
+public:
+  void runOnOperation() override {
+    OwningRewritePatternList patterns;
+    ConversionTarget target(getContext());
+    target.addLegalDialect<tensor::TensorDialect, scf::SCFDialect>();
+    target.addIllegalOp<tosa::IfOp, tosa::WhileOp>();
+    target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
+
+    auto *op = getOperation();
+    mlir::tosa::populateTosaToSCFConversionPatterns(op->getContext(),
+                                                    &patterns);
+    if (failed(applyPartialConversion(op, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::tosa::createTosaToSCF() {
+  return std::make_unique<TosaToSCF>();
+}
+
+void mlir::tosa::addTosaToSCFPasses(OpPassManager &pm) {
+  pm.addNestedPass<FuncOp>(createTosaToSCF());
+}
diff --git a/mlir/lib/Conversion/TosaToStandard/CMakeLists.txt b/mlir/lib/Conversion/TosaToStandard/CMakeLists.txt
new file mode 100644 (file)
index 0000000..43032f0
--- /dev/null
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRTosaToStandard
+  TosaToStandard.cpp
+  TosaToStandardPass.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRStandard
+  MLIRPass
+  MLIRTosa
+  MLIRTosaTransforms
+  MLIRSupport
+  )
diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
new file mode 100644 (file)
index 0000000..21a8da2
--- /dev/null
@@ -0,0 +1,40 @@
+//===- TosaToStandard.cpp - Lowering Tosa to Standard Dialect -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// These rewriters lower from the Tosa to the Standard dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/TosaToStandard/TosaToStandard.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace tosa;
+
+namespace {
+
+class ConstOpConverter : public OpRewritePattern<tosa::ConstOp> {
+public:
+  using OpRewritePattern<tosa::ConstOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ConstOp op,
+                                PatternRewriter &rewriter) const final {
+    rewriter.replaceOpWithNewOp<::ConstantOp>(op, op.value());
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::tosa::populateTosaToStandardConversionPatterns(
+    MLIRContext *context, OwningRewritePatternList *patterns) {
+  patterns->insert<ConstOpConverter>(context);
+}
diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
new file mode 100644 (file)
index 0000000..225855e
--- /dev/null
@@ -0,0 +1,52 @@
+//===- TosaToStandardPass.cpp - Lowering Tosa to Linalg Dialect -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transformation pass legalizes Tosa operations to the Standard dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "../PassDetail.h"
+#include "mlir/Conversion/TosaToStandard/TosaToStandard.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace tosa;
+
+namespace {
+struct TosaToStandard : public TosaToStandardBase<TosaToStandard> {
+public:
+  void runOnOperation() override {
+    OwningRewritePatternList patterns;
+    ConversionTarget target(getContext());
+    target.addIllegalOp<tosa::ConstOp>();
+    target.addLegalOp<ConstantOp>();
+
+    auto *op = getOperation();
+    mlir::tosa::populateTosaToStandardConversionPatterns(op->getContext(),
+                                                         &patterns);
+    if (failed(applyPartialConversion(op, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::tosa::createTosaToStandard() {
+  return std::make_unique<TosaToStandard>();
+}
+
+void mlir::tosa::addTosaToStandardPasses(OpPassManager &pm) {
+  pm.addNestedPass<FuncOp>(createTosaToStandard());
+}
diff --git a/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir b/mlir/test/Conversion/TosaToSCF/tosa-to-scf.mlir
new file mode 100644 (file)
index 0000000..82fa2c9
--- /dev/null
@@ -0,0 +1,58 @@
+// RUN: mlir-opt --split-input-file --tosa-to-scf %s -verify-diagnostics -o -| FileCheck %s
+
+// CHECK-LABEL: func @while_test
+// CHECK-SAME: ([[ARG0:%.+]]: tensor<i32>)
+func @while_test(%arg0 : tensor<i32>) -> (tensor<i32>) {
+  // CHECK: [[WHILE:%.+]] = scf.while ([[ARG1:%.+]] = [[ARG0]])
+  %1 = "tosa.while_loop"(%arg0) ( {
+  ^bb0(%arg2: tensor<i32>):
+    // CHECK: "tosa.const"
+    %2 = "tosa.const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
+
+    // CHECK: [[COMPARE:%.+]] = "tosa.greater_equal"
+    %3 = "tosa.greater_equal"(%2, %arg2) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+
+    // CHECK: [[EX:%.+]] = tensor.extract [[COMPARE]]
+    // CHECK: scf.condition([[EX]]) [[ARG1]]
+    "tosa.yield"(%3) : (tensor<i1>) -> ()
+  },  {
+  // CHECK: ^bb0([[ARG1:%.+]]: tensor<i32>)
+  ^bb0(%arg2: tensor<i32>):
+    // CHECK: tosa.const
+    %2 = "tosa.const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+
+    // CHECK: [[ADD:%.+]] = "tosa.add"
+    %3 = "tosa.add"(%arg2, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+
+    // CHECK: scf.yield [[ADD]]
+    "tosa.yield"(%3) : (tensor<i32>) -> ()
+  }) : (tensor<i32>) -> (tensor<i32>)
+  return %1 : tensor<i32>
+}
+
+// ----
+
+// CHECK-LABEL: func @if_test
+// CHECK-SAME: ([[ARG0:%.+]]: tensor<f32>, [[ARG1:%.+]]: tensor<f32>, [[ARG2:%.+]]: tensor<i1>)
+func @if_test(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> (tensor<f32>) {
+  // CHECK: [[EX:%.+]] = tensor.extract [[ARG2]]
+  // CHECK: [[IF:%.+]] = scf.if [[EX]] -> (tensor<f32>) {
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+
+  // CHECK:   scf.yield [[ARG0]]
+  ^bb1(%arg3 : tensor<f32>, %arg4 : tensor<f32>):
+    "tosa.yield"(%arg3) : (tensor<f32>) -> ()
+
+  // CHECK: } else {
+  }, {
+
+  // CHECK:   scf.yield [[ARG1]]
+  ^bb1(%arg5 : tensor<f32>, %arg6 : tensor<f32>):
+    "tosa.yield"(%arg6) : (tensor<f32>) -> ()
+
+  // CHECK: }
+  // CHECK: return [[IF]]
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> (tensor<f32>)
+
+  return %0 : tensor<f32>
+}
diff --git a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
new file mode 100644 (file)
index 0000000..86304dc
--- /dev/null
@@ -0,0 +1,10 @@
+// RUN: mlir-opt --split-input-file --tosa-to-standard %s -verify-diagnostics -o -| FileCheck %s
+
+// CHECK-LABEL: func @const_test
+func @const_test() -> (tensor<i32>) {
+  // CHECK: [[C3:%.+]] = constant dense<3> : tensor<i32>
+  %0 = "tosa.const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
+
+  // CHECK: return [[C3]]
+  return %0 : tensor<i32>
+}