[include "QuantPasses.md"]
-## `shape` Dialect Passes
-
-[include "ShapePasses.md"]
-
## `spv` Dialect Passes
[include "SPIRVPasses.md"]
add_subdirectory(IR)
-add_subdirectory(Transforms)
let arguments = (ins Shape_ShapeType:$shape, Variadic<AnyType>:$initVals);
let results = (outs Variadic<AnyType>:$result);
- let regions = (region SizedRegion<1>:$region);
+ let regions = (region SizedRegion<1>:$body);
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &result, "
+++ /dev/null
-set(LLVM_TARGET_DEFINITIONS Passes.td)
-mlir_tablegen(Passes.h.inc -gen-pass-decls)
-add_public_tablegen_target(MLIRShapeTransformsIncGen)
-
-add_mlir_doc(Passes -gen-pass-doc ShapePasses ./)
+++ /dev/null
-//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This header file defines prototypes that expose pass constructors in the
-// shape transformation library.
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES_H_
-#define MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES_H_
-
-#include <memory>
-
-namespace mlir {
-
-class Pass;
-
-/// Creates an instance of the ShapeToShapeLowering pass that legalizes Shape
-/// dialect to be convertible to Standard. For example, `shape.num_elements` get
-/// transformed to `shape.reduce`, which can be lowered to SCF and Standard.
-std::unique_ptr<Pass> createShapeToShapeLowering();
-
-} // end namespace mlir
-
-#endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES_H_
+++ /dev/null
-//===-- Passes.td - ShapeOps pass definition file ----------*- tablegen -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES
-#define MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES
-
-include "mlir/Pass/PassBase.td"
-
-def ShapeToShapeLowering : FunctionPass<"shape-to-shape-lowering"> {
- let summary = "Legalize Shape dialect to be convertible to Standard";
- let constructor = "mlir::createShapeToShapeLowering()";
-}
-
-#endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES
#include "mlir/Dialect/Quant/Passes.h"
#include "mlir/Dialect/SCF/Passes.h"
#include "mlir/Dialect/SPIRV/Passes.h"
-#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Transforms/LocationSnapshot.h"
#include "mlir/Transforms/Passes.h"
// Standard
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/StandardOps/Transforms/Passes.h.inc"
-
- // Shape
-#define GEN_PASS_REGISTRATION
-#include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
}
} // namespace mlir
MLIRIR
MLIRSideEffectInterfaces
)
-
-add_subdirectory(Transforms)
static LogicalResult verify(ReduceOp op) {
// Verify block arg types.
- Block &block = op.region().front();
+ Block &block = op.body().front();
auto blockArgsCount = op.initVals().size() + 2;
if (block.getNumArguments() != blockArgsCount)
p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals()
<< ") ";
p.printOptionalArrowTypeList(op.getResultTypes());
- p.printRegion(op.region());
+ p.printRegion(op.body());
p.printOptionalAttrDict(op.getAttrs());
}
+++ /dev/null
-add_mlir_dialect_library(MLIRShapeOpsTransforms
- ShapeToShapeLowering.cpp
-
- ADDITIONAL_HEADER_DIRS
- ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ShapeOps/Transforms
-
- DEPENDS
- MLIRShapeTransformsIncGen
- )
-
-target_link_libraries(MLIRShapeOpsTransforms
- PUBLIC
- MLIRIR
- MLIRPass
- MLIRShape
- MLIRSupport
- )
+++ /dev/null
-//===- PassDetail.h - Shape Pass class details ------------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef DIALECT_SHAPE_TRANSFORMS_PASSDETAIL_H_
-#define DIALECT_SHAPE_TRANSFORMS_PASSDETAIL_H_
-
-#include "mlir/Pass/Pass.h"
-
-namespace mlir {
-
-#define GEN_PASS_CLASSES
-#include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
-
-} // end namespace mlir
-
-#endif // DIALECT_SHAPE_TRANSFORMS_PASSDETAIL_H_
+++ /dev/null
-//===- ShapeToShapeLowering.cpp - Prepare for lowering to Standard --------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "PassDetail.h"
-#include "mlir/Dialect/Shape/IR/Shape.h"
-#include "mlir/Dialect/Shape/Transforms/Passes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-using namespace mlir;
-using namespace mlir::shape;
-
-namespace {
-/// Converts `shape.num_elements` to `shape.reduce`.
-struct NumElementsOpConverter : public OpRewritePattern<NumElementsOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(NumElementsOp op,
- PatternRewriter &rewriter) const final;
-};
-} // namespace
-
-LogicalResult
-NumElementsOpConverter::matchAndRewrite(NumElementsOp op,
- PatternRewriter &rewriter) const {
- auto loc = op.getLoc();
- Value init = rewriter.create<ConstSizeOp>(loc, rewriter.getIndexAttr(1));
- ReduceOp reduce = rewriter.create<ReduceOp>(loc, op.shape(), init);
-
- // Generate reduce operator.
- Block *body = reduce.getBody();
- OpBuilder b = OpBuilder::atBlockEnd(body);
- Value product =
- b.create<MulOp>(loc, body->getArgument(1), body->getArgument(2));
- b.create<YieldOp>(loc, product);
-
- rewriter.replaceOp(op, reduce.result());
- return success();
-}
-
-namespace {
-struct ShapeToShapeLowering
- : public ShapeToShapeLoweringBase<ShapeToShapeLowering> {
- void runOnFunction() override;
-};
-} // namespace
-
-void ShapeToShapeLowering::runOnFunction() {
- OwningRewritePatternList patterns;
- patterns.insert<NumElementsOpConverter>(&getContext());
-
- ConversionTarget target(getContext());
- target.addLegalDialect<ShapeDialect>();
- target.addIllegalOp<NumElementsOp>();
- if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
- signalPassFailure();
-}
-
-std::unique_ptr<Pass> mlir::createShapeToShapeLowering() {
- return std::make_unique<ShapeToShapeLowering>();
-}
+++ /dev/null
-// RUN: mlir-opt -shape-to-shape-lowering -split-input-file %s | FileCheck %s --dump-input-on-failure
-
-// CHECK-LABEL: func @num_elements_to_reduce(
-// CHECK-SAME: [[ARG:%.*]]: !shape.shape) -> [[SIZE_TY:!.*]] {
-func @num_elements_to_reduce(%shape : !shape.shape) -> !shape.size {
- %num_elements = shape.num_elements %shape
- return %num_elements : !shape.size
-}
-// CHECK: [[C1:%.*]] = shape.const_size 1
-// CHECK: [[NUM_ELEMENTS:%.*]] = shape.reduce([[ARG]], [[C1]]) -> [[SIZE_TY]]
-// CHECK: ^bb0({{.*}}: index, [[DIM:%.*]]: [[SIZE_TY]], [[ACC:%.*]]: [[SIZE_TY]]
-// CHECK: [[NEW_ACC:%.*]] = shape.mul [[DIM]], [[ACC]]
-// CHECK: shape.yield [[NEW_ACC]] : [[SIZE_TY]]
-// CHECK: }
-// CHECK: return [[NUM_ELEMENTS]] : [[SIZE_TY]]
-