From cf9e88ff593f12bf01c3a6bc0e631bfb5f8bc0bd Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Thu, 13 Apr 2023 01:00:23 -0700 Subject: [PATCH] [mlir][Linalg] NFC - Extract an IndexingUtils from Linalg/Utils Differential Revision: https://reviews.llvm.org/D148201 --- .../mlir/Dialect/Linalg/Utils/IndexingUtils.h | 47 +++++++++++++ mlir/include/mlir/Dialect/Linalg/Utils/Utils.h | 12 +--- .../Linalg/Transforms/ElementwiseToLinalg.cpp | 2 +- mlir/lib/Dialect/Linalg/Utils/CMakeLists.txt | 3 +- mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp | 82 ++++++++++++++++++++++ mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 30 -------- 6 files changed, 134 insertions(+), 42 deletions(-) create mode 100644 mlir/include/mlir/Dialect/Linalg/Utils/IndexingUtils.h create mode 100644 mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Linalg/Utils/IndexingUtils.h new file mode 100644 index 0000000..f0ad07c --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/Utils/IndexingUtils.h @@ -0,0 +1,47 @@ +//===- IndexingUtils.h - Indexing utilities supporting Linalg ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LINALG_UTILS_INDEXINGUTILS_H +#define MLIR_DIALECT_LINALG_UTILS_INDEXINGUTILS_H + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "llvm/ADT/StringSet.h" +#include + +namespace mlir { +namespace linalg { + +/// Create one memref::DimOp or tensor::DimOp depending on the type of `val`. +/// This is a polymorphic convenience function to abstract away the rank and +/// concrete type of `val`. +/// Asserts that `val` is a memref or tensor type. +Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim); + +/// Create one memref::DimOp or tensor::DimOp depending on the type of `val`. +/// This is a polymorphic convenience function to abstract away the rank and +/// concrete type of `val`. +/// Asserts that `val` is a memref or tensor type. +OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, + int64_t dim); + +/// Build the list of DimOp for the dynamic dimensions of `val`. +/// Asserts that `val` is a ranked shaped type. +SmallVector createDynamicDimensions(OpBuilder &b, Location loc, + Value val); + +/// Build the list of all dimensions for `val`, mixing static attributes and +/// dynamic values where appropriate. +/// Asserts that `val` is a ranked shaped type. +SmallVector getMixedDimensions(OpBuilder &b, Location loc, + Value val); + +} // namespace linalg +} // namespace mlir +#endif // MLIR_DIALECT_LINALG_UTILS_INDEXINGUTILS_H diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 4e42afa..828a062 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -10,6 +10,7 @@ #define MLIR_DIALECT_LINALG_UTILS_UTILS_H #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Utils/IndexingUtils.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "llvm/ADT/StringSet.h" @@ -84,16 +85,6 @@ bool isParallelIterator(utils::IteratorType iteratorType); /// Check if iterator type has "reduction" semantics. bool isReductionIterator(utils::IteratorType iteratorType); -/// Helper function that creates a memref::DimOp or tensor::DimOp depending on -/// the type of `source`. -Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim); -OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value source, - int64_t dim); - -/// Given an operation, retrieves the value of each dynamic dimension through -/// constructing the necessary DimOp operators. -SmallVector getDynOperands(Location loc, Value val, OpBuilder &b); - /// Create a tensor::PadOp that pads `source` to the size of the statically /// sized `type` whose static sizes are assumed to be greater than the dynamic /// `source` size. The padding introduces trailing `pad` values until the @@ -271,6 +262,7 @@ struct FusionInfo { /// transformation). FailureOr fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand); + /// Tensor counterpart of `fuseProducerOfBuffer`. /// This implements the fusion part of the "tileAndFuse on tensors" /// transformation and thus requires the `consumerOpOperand` to be a diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp index 52287f1..549764d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp @@ -69,7 +69,7 @@ getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op) { Value firstOperand = operands.front(); auto rankedTensorType = t.cast(); auto staticShape = llvm::to_vector<4>(rankedTensorType.getShape()); - auto dynamicShape = linalg::getDynOperands(loc, firstOperand, b); + auto dynamicShape = linalg::createDynamicDimensions(b, loc, firstOperand); res.push_back(b.create( loc, staticShape, rankedTensorType.getElementType(), dynamicShape)); diff --git a/mlir/lib/Dialect/Linalg/Utils/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Utils/CMakeLists.txt index da0a3c2..c34820b 100644 --- a/mlir/lib/Dialect/Linalg/Utils/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Utils/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRLinalgUtils Utils.cpp + IndexingUtils.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg @@ -16,4 +17,4 @@ add_mlir_dialect_library(MLIRLinalgUtils MLIRPass MLIRTensorUtils MLIRTransformUtils - ) +) diff --git a/mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp new file mode 100644 index 0000000..12b55ef --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp @@ -0,0 +1,82 @@ +//===- IndexingUtils.cpp - Indexing utilities supporting Linalg -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements indexing utilities for the Linalg dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/Utils/Utils.h" + +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Affine/Analysis/AffineStructures.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/IR/AffineValueMap.h" +#include "mlir/Dialect/Affine/LoopUtils.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Utils/Utils.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineExprVisitor.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/SetOperations.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" +#include + +#define DEBUG_TYPE "linalg-utils" + +namespace mlir { +namespace linalg { +Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim) { + if (val.getType().isa()) + return b.createOrFold(loc, val, dim); + if (val.getType().isa()) + return b.createOrFold(loc, val, dim); + llvm_unreachable("Expected MemRefType or TensorType"); +} + +OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, + int64_t dim) { + auto shapedType = val.getType().cast(); + if (!shapedType.hasRank() || shapedType.isDynamicDim(dim)) + return createOrFoldDimOp(b, loc, val, dim); + return b.getIndexAttr(shapedType.getDimSize(dim)); +} + +SmallVector createDynamicDimensions(OpBuilder &b, Location loc, + Value val) { + auto shapedType = val.getType().cast(); + assert(shapedType.hasRank() && "`val` must have a static rank"); + SmallVector res; + res.reserve(shapedType.getRank()); + for (const auto &dim : llvm::enumerate(shapedType.getShape())) { + if (dim.value() == ShapedType::kDynamic) + res.push_back(createOrFoldDimOp(b, loc, val, dim.index())); + } + return res; +} + +SmallVector getMixedDimensions(OpBuilder &b, Location loc, + Value val) { + auto shapedType = val.getType().cast(); + assert(shapedType.hasRank() && "`val` must have a static rank"); + SmallVector dynamicDims = createDynamicDimensions(b, loc, val); + return getMixedValues(shapedType.getShape(), dynamicDims, b); +} +} // namespace linalg +} // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 84fd14b..4c8f31b 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -267,36 +267,6 @@ bool isReductionIterator(utils::IteratorType iteratorType) { return iteratorType == utils::IteratorType::reduction; } -/// Helper function that creates a memref::DimOp or tensor::DimOp depending on -/// the type of `source`. -Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim) { - if (source.getType().isa()) - return b.createOrFold(loc, source, dim); - if (source.getType().isa()) - return b.createOrFold(loc, source, dim); - llvm_unreachable("Expected MemRefType or TensorType"); -} - -OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value source, - int64_t dim) { - auto shapedType = source.getType().cast(); - if (!shapedType.hasRank() || shapedType.isDynamicDim(dim)) - return createOrFoldDimOp(b, loc, source, dim); - return b.getIndexAttr(shapedType.getDimSize(dim)); -} - -/// Given an operation, retrieves the value of each dynamic dimension through -/// constructing the necessary DimOp operators. -SmallVector getDynOperands(Location loc, Value val, OpBuilder &b) { - SmallVector dynOperands; - auto shapedType = val.getType().cast(); - for (const auto &dim : llvm::enumerate(shapedType.getShape())) { - if (dim.value() == ShapedType::kDynamic) - dynOperands.push_back(createOrFoldDimOp(b, loc, val, dim.index())); - } - return dynOperands; -} - Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type, Value source, Value pad, bool nofold) { // Exit if `source` is not defined by an ExtractSliceOp. -- 2.7.4