[mlir] Add TransposeOp to Linalg structured ops.
authorOleg Shyshkov <shyshkov@google.com>
Wed, 19 Oct 2022 09:42:25 +0000 (11:42 +0200)
committerOleg Shyshkov <shyshkov@google.com>
Wed, 19 Oct 2022 10:27:52 +0000 (12:27 +0200)
RFC: https://discourse.llvm.org/t/rfc-primitive-ops-add-mapop-reductionop-transposeop-broadcastop-to-linalg/64184

Differential Revision: https://reviews.llvm.org/D135854

mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/test/Dialect/Linalg/invalid.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir

index 28c75fc..e231bdd 100644 (file)
@@ -70,6 +70,10 @@ AffineMap extractOrIdentityMap(Optional<AffineMap> maybeMap, unsigned rank,
 SmallVector<AffineExpr, 4> concat(ArrayRef<AffineExpr> a,
                                   ArrayRef<AffineExpr> b);
 
+/// Check if `permutation` is a permutation of the range
+/// `[0, permutation.size())`.
+bool isPermutation(ArrayRef<int64_t> permutation);
+
 } // namespace linalg
 } // namespace mlir
 
index 4b83de1..9c2246e 100644 (file)
@@ -361,6 +361,78 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
 
 
 //===----------------------------------------------------------------------===//
+// Transpose op.
+//===----------------------------------------------------------------------===//
+
+def TransposeOp : LinalgStructuredBase_Op<"transpose", [
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    SameVariadicOperandSize,
+    SingleBlockImplicitTerminator<"YieldOp">]> {
+  let summary = "Transpose operator";
+  let description = [{
+    Permutes the dimensions of `input` according to the given `permutation`.
+      `dim(result, i) = dim(input, permutation[i])`
+
+    This op actually moves data, unlike `memref.transpose` which is a metadata
+    operation only that produces a transposed "view".
+
+    Example:
+    ```
+      %transpose = linalg.transpose
+          ins(%input:tensor<16x64xf32>)
+          outs(%init:tensor<64x16xf32>)
+          permutation = [1, 0]
+    ```
+  }];
+
+  let arguments = (ins
+    // Input arg
+    TensorOrMemref:$input,
+    // Output arg
+    TensorOrMemref:$init,
+
+    DenseI64ArrayAttr:$permutation
+  );
+  let results = (outs Variadic<AnyTensor>:$result);
+  let regions = (region SizedRegion<1>:$region);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins "Value":$input, "Value":$init,
+        "DenseI64ArrayAttr":$permutation, CArg<"ArrayRef<NamedAttribute>",
+        "{}">:$attributes)>,
+    OpBuilder<(ins "Value":$input, "Value":$init,
+        "ArrayRef<int64_t>":$permutation, CArg<"ArrayRef<NamedAttribute>",
+        "{}">:$attributes)>,
+  ];
+
+  let extraClassDeclaration = structuredOpsBaseDecls # [{
+    // Declare functions necessary for LinalgStructuredInterface.
+    SmallVector<StringRef> getIteratorTypesArray();
+    ArrayAttr getIndexingMaps();
+    std::string getLibraryCallName() {
+      return "op_has_no_registered_library_name";
+    }
+
+    // Implement functions necessary for DestinationStyleOpInterface.
+    std::pair<int64_t, int64_t> getOutputsPositionRange() {
+      int64_t getNumOperands = this->getNumOperands();
+      return {getNumOperands - 1, getNumOperands};
+    }
+
+    static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
+        mlir::ArrayRef<mlir::NamedAttribute>)>
+      getRegionBuilder();
+
+    static void createRegion(::mlir::OpBuilder &opBuilder,
+                             ::mlir::OperationState & odsState);
+  }];
+
+  let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
 // Named Linalg ops, implemented as a declarative configurations of generic ops.
 //===----------------------------------------------------------------------===//
 
index 305b859..6a10d43 100644 (file)
@@ -41,10 +41,6 @@ bool hasOnlyScalarElementwiseOp(Region &r);
 /// Check if a LinalgOp is an element-wise operation.
 bool isElementwise(LinalgOp op);
 
-/// Check if `permutation` is a permutation of the range
-/// `[0, permutation.size())`.
-bool isPermutation(ArrayRef<int64_t> permutation);
-
 /// Check if iterator type has "parallel" semantics.
 bool isParallelIterator(StringRef iteratorType);
 
index 2fcd21c..82e5024 100644 (file)
@@ -1602,6 +1602,142 @@ LogicalResult ReduceOp::verify() {
 }
 
 //===----------------------------------------------------------------------===//
+// TransposeOp
+//===----------------------------------------------------------------------===//
+
+std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
+                   mlir::ArrayRef<mlir::NamedAttribute>)>
+TransposeOp::getRegionBuilder() {
+  return [](mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
+            mlir::ArrayRef<mlir::NamedAttribute>) {
+    b.create<linalg::YieldOp>(block.getArguments().back());
+  };
+}
+
+void TransposeOp::createRegion(::mlir::OpBuilder &opBuilder,
+                               ::mlir::OperationState &odsState) {
+  Region *region = odsState.addRegion();
+
+  SmallVector<Type> argTypes;
+  SmallVector<Location> argLocs;
+  for (auto t : odsState.operands) {
+    argTypes.push_back(getElementTypeOrSelf(t));
+    argLocs.push_back(opBuilder.getUnknownLoc());
+  }
+
+  // RAII.
+  OpBuilder::InsertionGuard guard(opBuilder);
+  Block *body =
+      opBuilder.createBlock(region, /*insertPt=*/{}, argTypes, argLocs);
+
+  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
+  getRegionBuilder()(b, *body, odsState.attributes.getAttrs());
+}
+
+void TransposeOp::build(::mlir::OpBuilder &odsBuilder,
+                        ::mlir::OperationState &odsState, Value input,
+                        Value init, DenseI64ArrayAttr permutation,
+                        ArrayRef<NamedAttribute> attributes) {
+  odsState.addOperands(input);
+  odsState.addOperands(init);
+  odsState.addAttribute(getPermutationAttrName(odsState.name), permutation);
+  odsState.addAttributes(attributes);
+  odsState.addTypes(init.getType());
+
+  createRegion(odsBuilder, odsState);
+}
+
+void TransposeOp::build(::mlir::OpBuilder &odsBuilder,
+                        ::mlir::OperationState &odsState, Value input,
+                        Value init, ArrayRef<int64_t> permutation,
+                        ArrayRef<NamedAttribute> attributes) {
+  build(odsBuilder, odsState, input, init,
+        odsBuilder.getDenseI64ArrayAttr(permutation), attributes);
+}
+
+ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
+  if (failed(parseDstStyleOp(
+          parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
+            return parseDenseI64ArrayAttr(parser, attributes, "permutation");
+          })))
+    return failure();
+
+  OpBuilder opBuilder(parser.getContext());
+  createRegion(opBuilder, result);
+  return success();
+}
+
+void TransposeOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  if (!getResults().empty())
+    setNameFn(getResults().front(), "transposed");
+}
+
+void TransposeOp::print(OpAsmPrinter &p) {
+  printCommonStructuredOpParts(p, SmallVector<Value>(getInputOperands()),
+                               SmallVector<Value>(getOutputOperands()));
+  printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
+  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
+}
+
+LogicalResult TransposeOp::verify() {
+  ArrayRef<int64_t> permutationRef = getPermutation();
+
+  if (!isPermutation(permutationRef))
+    return emitOpError("permutation is not valid");
+
+  auto inputType = getInput().getType();
+  auto initType = getInit().getType();
+
+  int64_t rank = inputType.getRank();
+
+  if (rank != initType.getRank())
+    return emitOpError() << "input rank " << rank
+                         << " does not match init rank " << initType.getRank();
+
+  if (rank != static_cast<int64_t>(permutationRef.size()))
+    return emitOpError() << "size of permutation " << permutationRef.size()
+                         << " does not match the argument rank " << rank;
+
+  auto inputDims = inputType.getShape();
+  auto initDims = initType.getShape();
+
+  for (int64_t i = 0; i < rank; ++i) {
+    int64_t inputDim = inputDims[permutationRef[i]];
+    int64_t initDim = initDims[i];
+
+    if (inputDim != initDim) {
+      return emitOpError() << "dim(result, " << i << ") = " << initDim
+                           << " doesn't match dim(input, permutation[" << i
+                           << "]) = " << inputDim;
+    }
+  }
+
+  return success();
+}
+
+SmallVector<StringRef> TransposeOp::getIteratorTypesArray() {
+  int64_t rank = getInit().getType().getRank();
+  return SmallVector<StringRef>(rank, getParallelIteratorTypeName());
+}
+
+ArrayAttr TransposeOp::getIndexingMaps() {
+  Builder builder(getContext());
+  int64_t rank = getInit().getType().getRank();
+  return builder.getAffineMapArrayAttr(
+      {builder.getMultiDimIdentityMap(rank),
+       AffineMap::getPermutationMap(
+           llvm::to_vector_of<unsigned>(getPermutation()), getContext())});
+}
+
+void TransposeOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  getGenericEffectsImpl(effects, getOperation()->getResults(),
+                        getInputOperands(), getOutputOperands());
+}
+
+//===----------------------------------------------------------------------===//
 // YieldOp
 //===----------------------------------------------------------------------===//
 
@@ -1710,6 +1846,19 @@ SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
   return llvm::to_vector<4>(concatRanges);
 }
 
+bool mlir::linalg::isPermutation(ArrayRef<int64_t> permutation) {
+  // Count the number of appearances for all indices.
+  SmallVector<int64_t> indexCounts(permutation.size(), 0);
+  for (auto index : permutation) {
+    // Exit if the index is out-of-range.
+    if (index < 0 || index >= static_cast<int64_t>(permutation.size()))
+      return false;
+    ++indexCounts[index];
+  }
+  // Return true if all indices appear once.
+  return count(indexCounts, 1) == static_cast<int64_t>(permutation.size());
+}
+
 static void appendMangledType(llvm::raw_string_ostream &ss, Type t) {
   if (auto memref = t.dyn_cast<MemRefType>()) {
     ss << "view";
index aba2d5f..af5a201 100644 (file)
@@ -186,19 +186,6 @@ bool isElementwise(LinalgOp op) {
   return hasOnlyScalarElementwiseOp(op->getRegion(0));
 }
 
-bool isPermutation(ArrayRef<int64_t> permutation) {
-  // Count the number of appearances for all indices.
-  SmallVector<int64_t> indexCounts(permutation.size(), 0);
-  for (auto index : permutation) {
-    // Exit if the index is out-of-range.
-    if (index < 0 || index >= static_cast<int64_t>(permutation.size()))
-      return false;
-    indexCounts[index]++;
-  }
-  // Return true if all indices appear once.
-  return count(indexCounts, 1) == static_cast<int64_t>(permutation.size());
-}
-
 bool isParallelIterator(StringRef iteratorType) {
   return iteratorType == getParallelIteratorTypeName();
 }
index 00352c4..e6ab837 100644 (file)
@@ -624,3 +624,52 @@ func.func @reduce_different_output_shapes(%input1: tensor<16x32x64xf32>,
       }
   func.return %reduce, %reduce2 : tensor<16x64xf32>, tensor<17x64xf32>
 }
+
+// -----
+
+func.func @transpose_invalid_permutation(%input: tensor<16x32x64xf32>,
+    %init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
+  // expected-error @+1 {{'linalg.transpose' op permutation is not valid}}
+  %transpose = linalg.transpose
+      ins(%input:tensor<16x32x64xf32>)
+      outs(%init:tensor<32x64x16xf32>)
+      permutation = [1, 1, 2]
+  func.return %transpose : tensor<32x64x16xf32>
+}
+
+// -----
+
+func.func @transpose_permutated_dims_mismatch(%input: tensor<16x32x64xf32>,
+    %init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
+  // expected-error @+1 {{'linalg.transpose' op dim(result, 0) = 32 doesn't match dim(input, permutation[0]) = 16}}
+  %transpose = linalg.transpose
+      ins(%input:tensor<16x32x64xf32>)
+      outs(%init:tensor<32x64x16xf32>)
+      permutation = [0, 1, 2]
+  func.return %transpose : tensor<32x64x16xf32>
+}
+
+// -----
+
+func.func @transpose_rank_permutation_size_mismatch(
+    %input: tensor<16x32x64xf32>,
+    %init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
+  // expected-error @+1 {{'linalg.transpose' op size of permutation 2 does not match the argument rank 3}}
+  %transpose = linalg.transpose
+      ins(%input:tensor<16x32x64xf32>)
+      outs(%init:tensor<32x64x16xf32>)
+      permutation = [1, 0]
+  func.return %transpose : tensor<32x64x16xf32>
+}
+
+// -----
+
+func.func @transpose_input_init_rank_mismatch(%input: tensor<16x32xf32>,
+    %init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
+  // expected-error @+1 {{'linalg.transpose' op input rank 2 does not match init rank 3}}
+  %transpose = linalg.transpose
+      ins(%input:tensor<16x32xf32>)
+      outs(%init:tensor<32x64x16xf32>)
+      permutation = [1, 0, 2]
+  func.return %transpose : tensor<32x64x16xf32>
+}
index f751ddf..4bea3f6 100644 (file)
@@ -67,11 +67,11 @@ func.func @fill_view(%arg0: memref<?xf32, strided<[1], offset: ?>>, %arg1: f32)
 
 // -----
 
-func.func @transpose(%arg0: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>) {
+func.func @memref_transpose(%arg0: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>) {
   %0 = memref.transpose %arg0 (i, j, k) -> (k, j, i) : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
   return
 }
-// CHECK-LABEL: func @transpose
+// CHECK-LABEL: func @memref_transpose
 //       CHECK:   memref.transpose %{{.*}} ([[i:.*]], [[j:.*]], [[k:.*]]) -> ([[k]], [[j]], [[i]]) :
 //  CHECK-SAME:      memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
 
@@ -457,3 +457,27 @@ func.func @variadic_reduce_memref(%input1: memref<16x32x64xf32>,
 }
 // CHECK-LABEL: func @variadic_reduce_memref
 //       CHECK:     linalg.reduce
+
+// -----
+
+func.func @transpose(%input: tensor<16x32x64xf32>,
+                     %init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
+  %transpose = linalg.transpose
+      ins(%input:tensor<16x32x64xf32>)
+      outs(%init:tensor<32x64x16xf32>)
+      permutation = [1, 2, 0]
+  func.return %transpose : tensor<32x64x16xf32>
+}
+// CHECK-LABEL: func @transpose
+
+// -----
+
+func.func @transpose_memref(%input: memref<16x32x64xf32>,
+                            %init: memref<32x64x16xf32>) {
+  linalg.transpose
+      ins(%input:memref<16x32x64xf32>)
+      outs(%init:memref<32x64x16xf32>)
+      permutation = [1, 2, 0]
+  func.return
+}
+// CHECK-LABEL: func @transpose_memref