include "mlir/Dialect/Linalg/IR/LinalgBase.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
let hasVerifier = 1;
}
+def Linalg_SoftmaxOp : Linalg_Op<"softmax",
+ [DestinationStyleOpInterface,
+ PredOpTrait<"input and output have same element type", TCopVTEtIsSameAs<0, 1>>,
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+ let summary = "Softmax operator";
+ let description = [{
+ linalg.softmax computes a numerically stable version of softmax.
+
+ For a given input tensor and a specified dimension `d`, compute:
+ 1. the max `m` along dimension `d`,
+ 2. `f(x) = exp(x - m)`,
+ 3. the sum `l(x)` of `f(x)` along dimension `d`,
+ 4. the final result `f(x) / l(x)`.
+
+ This is an aggregate linalg operation that can be decomposed into a small
+ DAG of structured operations.
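+
+ Example:
+
+ ```mlir
+ %1 = linalg.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>) outs(%0 : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+ ```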
+ }];
+
+ let arguments = (ins AnyShaped:$input,
+ AnyShaped:$output,
+ I64Attr:$dimension
+ );
+
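+ // `$result` is variadic: the tensor form has one result, while the memref
+ // form has none (the trailing `->` group in the assembly format is then
+ // omitted).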
+ let results = (outs Variadic<AnyRankedTensor>:$result);
+ let hasFolder = 1;
+ let assemblyFormat = [{
+ attr-dict
+ `dimension` `(` $dimension `)`
+ `ins` `(` $input `:` type($input) `)`
+ `outs` `(` $output `:` type($output) `)`
+ (`->` type($result)^)?
+ }];
+
+ let extraClassDeclaration = [{
+ ShapedType getInputOperandType() {
+ return getInput().getType().cast<ShapedType>();
+ }
+ ShapedType getOutputOperandType() {
+ return getOutput().getType().cast<ShapedType>();
+ }
+ int64_t getInputOperandRank() {
+ return getInputOperandType().getRank();
+ }
+ int64_t getOutputOperandRank() {
+ return getOutputOperandType().getRank();
+ }
+ // Method to implement DestinationStyleOpInterface.
+ std::pair<int64_t, int64_t> getDpsInitsPositionRange() {
+ std::pair<unsigned, unsigned> outputsIndexAndLength =
+ getODSOperandIndexAndLength(1);
+ return std::make_pair<int64_t, int64_t>(
+ outputsIndexAndLength.first,
+ outputsIndexAndLength.first + outputsIndexAndLength.second);
+ }
+ }];
+ let hasVerifier = 1;
+}
+
#endif // LINALG_OPS
// .cpp.inc.
//===----------------------------------------------------------------------===//
+// SoftmaxOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult SoftmaxOp::verify() {
+ ShapedType inputType = getInputOperandType();
+ ShapedType outputType = getOutputOperandType();
+
+ ArrayRef<int64_t> inputShape = inputType.getShape();
+ ArrayRef<int64_t> outputShape = outputType.getShape();
+ if (failed(verifyCompatibleShape(inputShape, outputShape)))
+ return emitOpError("incompatible output shape");
+
+ int64_t inputRank = getInputOperandRank();
+ int64_t dimension = getDimension();
+ if ((dimension < 0) || (dimension >= inputRank))
+ return emitOpError("incorrect dimension specified");
+
+ return success();
+}
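+
+// For example (illustrative IR), `dimension(3)` on a rank-3 input fails to
+// verify with "incorrect dimension specified".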
+
+// Fold `memref.cast` producers into this op, i.e. cast(dynamic) -> static.
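+//
+// For instance (illustrative memref-form IR):
+//   %cast = memref.cast %m : memref<2x16x32xf32> to memref<?x?x?xf32>
+//   linalg.softmax dimension(2) ins(%cast : memref<?x?x?xf32>)
+//     outs(%out : memref<2x16x32xf32>)
+// folds to the same op with `ins(%m : memref<2x16x32xf32>)`.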
+LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
+ return memref::foldMemRefCast(*this);
+}
+
+LogicalResult
+SoftmaxOp::reifyResultShapes(OpBuilder &b,
+ ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+ // SoftmaxOp does not implement the LinalgOp interface, so it cannot
+ // delegate to it here. Reify directly instead: the result, when present,
+ // always has the same shape as the output operand.
+ if (getNumResults() == 0)
+ return success();
+ ShapedType outputType = getOutputOperandType();
+ SmallVector<OpFoldResult> shapes;
+ for (int64_t dim = 0, e = getOutputOperandRank(); dim < e; ++dim)
+ shapes.push_back(outputType.isDynamicDim(dim)
+ ? OpFoldResult(b.createOrFold<tensor::DimOp>(getLoc(), getOutput(), dim))
+ : OpFoldResult(b.getIndexAttr(outputType.getDimSize(dim))));
+ reifiedReturnShapes.emplace_back(std::move(shapes));
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// LinalgDialect
//===----------------------------------------------------------------------===//
// CHECK-SAME: outs
// CHECK-SAME: dimensions = [1]
// CHECK-NEXT: return %[[REDUCED]] : tensor<16x64xf32>
+
+// -----
+
+func.func @softmax(%arg0: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+ %0 = tensor.empty() : tensor<2x16x32xf32>
+ %1 = linalg.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>) outs(%0: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+ return %1 : tensor<2x16x32xf32>
+}
+// CHECK: func.func @softmax(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
+// CHECK: %[[D0:.+]] = tensor.empty() : tensor<2x16x32xf32>
+// CHECK: %[[D1:.+]] = linalg.softmax dimension(2) ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D0]] :
+// CHECK-SAME: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+// CHECK: return %[[D1]] : tensor<2x16x32xf32>
+// CHECK: }
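+
+// -----
+
+// A minimal sketch of the memref form, assuming the `AnyShaped` operands
+// admit memrefs; in this form the op has no results and the trailing `->`
+// type is omitted.
+func.func @softmax_memref(%arg0: memref<2x16x32xf32>, %arg1: memref<2x16x32xf32>) {
+ linalg.softmax dimension(2) ins(%arg0 : memref<2x16x32xf32>) outs(%arg1 : memref<2x16x32xf32>)
+ return
+}
+// CHECK: func.func @softmax_memref
+// CHECK: linalg.softmax dimension(2) ins(%{{.*}} : memref<2x16x32xf32>) outs(%{{.*}} :
+// CHECK-SAME: memref<2x16x32xf32>)
+// CHECK: return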