Add a linalg.transpose op
authorNicolas Vasilache <ntv@google.com>
Fri, 23 Aug 2019 21:47:46 +0000 (14:47 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 23 Aug 2019 21:48:13 +0000 (14:48 -0700)
A linalg.transpose op is a pure metadata operation that takes a view + permutation map and produces
another view of the same underlying data, with a different reindexing. This is a
pure metadata operation that does not touch the underlying data.

Example:

```
  %t = linalg.transpose %v (i, j) -> (j, i) : !linalg.view<?x?xf32>
```

PiperOrigin-RevId: 265139429

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Linalg/invalid.mlir
mlir/test/Linalg/roundtrip.mlir

index ebd63bd..235817f 100644 (file)
 //
 //===----------------------------------------------------------------------===//
 
-include "mlir/Dialect/Linalg/IR/LinalgBase.td"
-
 #ifdef LINALG_OPS
 #else
 #define LINALG_OPS
 
+include "mlir/Dialect/AffineOps/AffineOpsBase.td"
+include "mlir/Dialect/Linalg/IR/LinalgBase.td"
+
 // Base class for Linalg dialect ops that do not correspond to library calls.
 class Linalg_Op<string mnemonic, list<OpTrait> traits = []> :
     Op<Linalg_Dialect, mnemonic, traits> {
@@ -331,10 +332,10 @@ def SubViewOp : Linalg_Op<"subview", [NoSideEffect]>,
     base view. This allows defining a subregion within the underlying buffer.
 
     The "linalg.subview" operation takes a base view, a list of indices and
-    returns a new linalg.view of the same type that is contained within the 
-    view. This operation is equivalent to a non-rank-reducing slice operation. 
-    The main difference is the operands are all of type `index` and no 
-    intermediate linalg.range operations are required. A "linalg.subview" is 
+    returns a new linalg.view of the same type that is contained within the
+    view. This operation is equivalent to a non-rank-reducing slice operation.
+    The main difference is the operands are all of type `index` and no
+    intermediate linalg.range operations are required. A "linalg.subview" is
     thus a specialized linalg.slice with a higher level of abstraction.
 
     Similary to linalg.slice, if a range extends past the size of the base view,
@@ -398,6 +399,37 @@ def SubViewOp : Linalg_Op<"subview", [NoSideEffect]>,
   }];
 }
 
+def TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>,
+    Arguments<(ins View:$view, AffineMapAttr:$permutation)>,
+    Results<(outs View)> {
+  let summary = "transpose operation produces a new view (metadata-only)";
+  let description = [{
+    The "linalg.transpose" op produces a linalg.view whose sizes and strides are
+    a permutation of the original. This is a pure metadata transformation.
+
+    Example:
+
+       %1 = linalg.transpose %0 (i, j) -> (j, i) : !linalg.view<?x?xf32>
+  }];
+
+  let builders = [OpBuilder<
+    "Builder *b, OperationState *result, Value *view, "
+    "AffineMapAttr permutation, ArrayRef<NamedAttribute> attrs = {}">];
+
+  let verifier = [{
+    if (!permutation().isPermutation())
+      return emitOpError("expected a permutation map");
+    if (permutation().getNumDims() != getViewType().getRank())
+      return emitOpError("expected a permutation map of same rank as the view");
+    return success();
+  }];
+
+  let extraClassDeclaration = [{
+    static StringRef getPermutationAttrName() { return "permutation"; }
+    ViewType getViewType() { return view()->getType().cast<ViewType>(); }
+  }];
+}
+
 def ViewOp : Linalg_Op<"view", [NoSideEffect]>,
     Arguments<(ins Buffer:$buffer, Variadic<Range>:$ranges)>,
     Results<(outs View)> {
index 881f72e..b5bbb59 100644 (file)
@@ -20,6 +20,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/LoopOps/LoopOps.h"
 #include "mlir/EDSC/Helpers.h"
 #include "mlir/IR/AffineExpr.h"
@@ -30,8 +32,6 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/StandardTypes.h"
-#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
-#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/STLExtras.h"
 #include "mlir/Transforms/FoldUtils.h"
@@ -600,6 +600,39 @@ static ParseResult parseSubViewOp(OpAsmParser *parser, OperationState *result) {
 }
 
 //===----------------------------------------------------------------------===//
+// TransposeOp
+//===----------------------------------------------------------------------===//
+void mlir::linalg::TransposeOp::build(Builder *b, OperationState *result,
+                                      Value *view, AffineMapAttr permutation,
+                                      ArrayRef<NamedAttribute> attrs) {
+  // TODO(ntv): once views have static dimensions, compute the permuted type.
+  build(b, result, view->getType(), view, attrs);
+  result->addAttribute(TransposeOp::getPermutationAttrName(), permutation);
+}
+
+static void print(OpAsmPrinter *p, TransposeOp op) {
+  *p << op.getOperationName() << " " << *op.view() << " " << op.permutation();
+  p->printOptionalAttrDict(op.getAttrs(),
+                           {TransposeOp::getPermutationAttrName()});
+  *p << " : " << op.view()->getType();
+}
+
+static ParseResult parseTransposeOp(OpAsmParser *parser,
+                                    OperationState *result) {
+  OpAsmParser::OperandType view;
+  AffineMapAttr permutation;
+  Type type;
+  return failure(parser->parseOperand(view) ||
+                 parser->parseAttribute(permutation,
+                                        TransposeOp::getPermutationAttrName(),
+                                        result->attributes) ||
+                 parser->parseOptionalAttributeDict(result->attributes) ||
+                 parser->parseColonType(type) ||
+                 parser->resolveOperand(view, type, result->operands) ||
+                 parser->addTypeToList(type, result->types));
+}
+
+//===----------------------------------------------------------------------===//
 // ViewOp
 //===----------------------------------------------------------------------===//
 void mlir::linalg::ViewOp::build(Builder *b, OperationState *result,
index d720c53..9ea9e51 100644 (file)
@@ -85,6 +85,20 @@ func @subview_number_of_indices(%v : !linalg.view<?x?xf32>) {
 
 // -----
 
+func @transpose_not_permutation(%v : !linalg.view<?x?xf32>) {
+  // expected-error @+1 {{expected a permutation map}}
+  linalg.transpose %v (i, j) -> (i, i) : !linalg.view<?x?xf32>
+}
+
+// -----
+
+func @transpose_bad_rank(%v : !linalg.view<?x?xf32>) {
+  // expected-error @+1 {{expected a permutation map of same rank as the view}}
+  linalg.transpose %v (i) -> (i) : !linalg.view<?x?xf32>
+}
+
+// -----
+
 func @view_type(%buf: !linalg.buffer<?xf32>, %min: index, %max: index, %step: index) {
   // expected-error @+2 {{expected view type}}
   %r = linalg.range %min:%max:%step : !linalg.range
index a42c8c1..eefa409 100644 (file)
@@ -1,7 +1,7 @@
 // RUN: mlir-opt %s | mlir-opt | FileCheck %s
 
-// CHECK: #[[map0:.*]] = (d0, d1, d2) -> (d0, d2, d1)
-// CHECK: #[[map1:.*]] = (d0, d1, d2) -> (d2, d1, d0)
+// CHECK-DAG: #[[map0:.*]] = (d0, d1, d2) -> (d0, d2, d1)
+// CHECK-DAG: #[[map1:.*]] = (d0, d1, d2) -> (d2, d1, d0)
 
 func @range(%arg0: index, %arg1: index, %arg2: index) {
   %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
@@ -121,6 +121,13 @@ func @fill_view(%arg0: !linalg.view<?xf32>, %arg1: f32) {
 // CHECK-LABEL: func @fill_view(%{{.*}}: !linalg.view<?xf32>, %{{.*}}: f32) {
 //       CHECK:   linalg.fill(%{{.*}}, %{{.*}}) : !linalg.view<?xf32>, f32
 
+func @transpose(%arg0: !linalg.view<?x?x?xf32>) {
+  %0 = linalg.transpose %arg0 (i, j, k) -> (k, j, i) : !linalg.view<?x?x?xf32>
+  return
+}
+// CHECK-LABEL: func @transpose
+//       CHECK:   linalg.transpose %{{.*}} ([[i:.*]], [[j:.*]], [[k:.*]]) -> ([[k]], [[j]], [[i]]) : !linalg.view<?x?x?xf32>
+
 func @fill_view3(%arg0: !linalg.view<?x?x?xf32>, %arg1: f32) {
   linalg.fill(%arg0, %arg1) : !linalg.view<?x?x?xf32>, f32
   return