[mlir][Linalg] Add a softmax op
authorQuentin Colombet <quentin.colombet@gmail.com>
Thu, 29 Jun 2023 10:25:15 +0000 (12:25 +0200)
committerQuentin Colombet <quentin.colombet@gmail.com>
Thu, 29 Jun 2023 10:57:06 +0000 (12:57 +0200)
This patch adds a softmax op.
For now, nothing interesting happens, we can only do a round trip.
Later patches will add the tiling interface and the lowering of this op to
a sequence of simpler ops.

This is graduating the linag_ext.softmax op from iree to LLVM.

Original implementation from Harsh Menon <harsh@nod-labs.com>
Nicolas Vasilache <nicolas.vasilache@gmail.com> co-authored this patch.

Differential Revision: https://reviews.llvm.org/D153422

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/invalid.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir

index d9c1eec..43b86cd 100644 (file)
@@ -15,6 +15,7 @@
 
 include "mlir/Dialect/Linalg/IR/LinalgBase.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -87,4 +88,62 @@ def Linalg_IndexOp : Linalg_Op<"index", [Pure]>,
   let hasVerifier = 1;
 }
 
+def Linalg_SoftmaxOp : Linalg_Op<"softmax",
+    [DestinationStyleOpInterface,
+     PredOpTrait<"input and output have same element type", TCopVTEtIsSameAs<0, 1>>,
+     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+  let summary = "Softmax operator";
+  let description = [{
+    linalg.softmax computes a numerically stable version of softmax.
+
+    For a given input tensor and a specified dimension `d`, compute:
+      1. the max `m` along that dimension `d`
+      2. f(x) = exp(x - m)
+      3. sum f(x) along dimension d to get l(x).
+      4. compute the final result f(x) / l(x).
+
+    This is an aggregate linalg operation that further reduces to a small DAG of
+    structured operations.
+  }];
+
+  let arguments = (ins AnyShaped:$input,
+                       AnyShaped:$output,
+                       I64Attr:$dimension
+  );
+
+  let results = (outs Variadic<AnyRankedTensor>:$result);
+  let hasFolder = 1;
+  let assemblyFormat = [{
+    attr-dict
+    `dimension` `(` $dimension `)`
+    `ins` `(` $input `:` type($input) `)`
+    `outs` `(` $output `:` type($output) `)`
+    (`->` type($result)^)?
+  }];
+
+  let extraClassDeclaration = [{
+    ShapedType getInputOperandType() {
+      return getInput().getType().cast<ShapedType>();
+    }
+    ShapedType getOutputOperandType() {
+      return getOutput().getType().cast<ShapedType>();
+    }
+    int64_t getInputOperandRank() {
+      return getInputOperandType().getRank();
+    }
+    int64_t getOutputOperandRank() {
+      return getOutputOperandType().getRank();
+    }
+    // Method to implement DestinationStyleOpInterface.
+    std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
+      std::pair<unsigned, unsigned> outputsIndexAndLength =
+        getODSOperandIndexAndLength(1);
+      return std::make_pair<int64_t, int64_t>(
+          outputsIndexAndLength.first,
+          outputsIndexAndLength.first + outputsIndexAndLength.second);
+    }
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // LINALG_OPS
index e5af203..03407e3 100644 (file)
@@ -2141,6 +2141,39 @@ struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
 // .cpp.inc.
 
 //===----------------------------------------------------------------------===//
+// SoftmaxOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult SoftmaxOp::verify() {
+  ShapedType inputType = getInputOperandType();
+  ShapedType outputType = getOutputOperandType();
+
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  if (failed(verifyCompatibleShape(inputShape, outputShape)))
+    return emitOpError("incompatible output shape");
+
+  int64_t inputRank = getInputOperandRank();
+  int64_t dimension = getDimension();
+  if ((dimension < 0) || (dimension >= inputRank))
+    return emitOpError("incorrect dimension specified");
+
+  return success();
+}
+
+// cast(dynamic) -> static.
+LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
+  return memref::foldMemRefCast(*this);
+}
+
+LogicalResult
+SoftmaxOp::reifyResultShapes(OpBuilder &b,
+                             ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  return cast<LinalgOp>(getOperation())
+      .reifyResultShapes(b, reifiedReturnShapes);
+}
+
+//===----------------------------------------------------------------------===//
 // LinalgDialect
 //===----------------------------------------------------------------------===//
 
index dbc93d5..88f070a 100644 (file)
@@ -733,3 +733,14 @@ func.func @missing_iterator_types() {
   linalg.generic {} ins() outs()
   return
 }
+
+// -----
+
+func.func @illegal_softmax_output_shape(%arg0: tensor<2x16x32xf32>) -> tensor<2x16xf32> {
+  %0 = tensor.empty() : tensor<2x16xf32>
+  // expected-error @+1 {{incompatible output shape}}
+  %1 = linalg.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>)
+                                   outs(%0: tensor<2x16xf32>)
+    -> tensor<2x16xf32>
+  return %1 : tensor<2x16xf32>
+}
index 8bf5e5b..9895cd2 100644 (file)
@@ -599,3 +599,17 @@ func.func @reduce_arith_with_attr(%input: tensor<16x32x64xf32>,
 // CHECK-SAME:    outs
 // CHECK-SAME:    dimensions = [1]
 // CHECK-NEXT:    return %[[REDUCED]] : tensor<16x64xf32>
+
+// -----
+
+func.func @softmax(%arg0: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+  %0 = tensor.empty() : tensor<2x16x32xf32>
+  %1 = linalg.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>) outs(%0: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+  return %1 : tensor<2x16x32xf32>
+}
+// CHECK:      func.func @softmax(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+// CHECK:        %[[D0:.+]] = tensor.empty() : tensor<2x16x32xf32>
+// CHECK:        %[[D1:.+]] = linalg.softmax dimension(2) ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D0]] :
+// CHECK-SAME:     tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+// CHECK:        return %[[D1]] : tensor<2x16x32xf32>
+// CHECK:      }