[MLIR] Add async token/value arguments to async.execute op
authorEugene Zhulenev <ezhulenev@google.com>
Thu, 8 Oct 2020 20:28:09 +0000 (13:28 -0700)
committerEugene Zhulenev <ezhulenev@google.com>
Fri, 9 Oct 2020 15:52:27 +0000 (08:52 -0700)
Async execute operation can take async arguments as dependencies.

Change `async.execute` custom parser/printer format to use `%value as %unwrapped: !async.value<!type>` sytax.

Reviewed By: mehdi_amini, herhut

Differential Revision: https://reviews.llvm.org/D88601

mlir/include/mlir/Dialect/Async/IR/Async.h
mlir/include/mlir/Dialect/Async/IR/AsyncBase.td
mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
mlir/lib/Dialect/Async/IR/Async.cpp
mlir/test/Dialect/Async/ops.mlir

index b1cf25e..1519ccd 100644 (file)
 #ifndef MLIR_DIALECT_ASYNC_IR_ASYNC_H
 #define MLIR_DIALECT_ASYNC_IR_ASYNC_H
 
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/StandardTypes.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 namespace mlir {
index 2097f05..c411c4a 100644 (file)
@@ -56,7 +56,11 @@ class Async_ValueType<Type type>
   Type valueType = type;
 }
 
-def Async_AnyValueType : Type<CPred<"$_self.isa<::mlir::async::ValueType>()">,
-                                    "async value type">;
+def Async_AnyValueType : DialectType<AsyncDialect,
+                           CPred<"$_self.isa<::mlir::async::ValueType>()">,
+                                 "async value type">;
+
+def Async_AnyValueOrTokenType : AnyTypeOf<[Async_AnyValueType,
+                                           Async_TokenType]>;
 
 #endif // ASYNC_BASE_TD
index 2dcc9a8..fbdbdb9 100644 (file)
@@ -24,7 +24,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 class Async_Op<string mnemonic, list<OpTrait> traits = []> :
     Op<AsyncDialect, mnemonic, traits>;
 
-def Async_ExecuteOp : Async_Op<"execute"> {
+def Async_ExecuteOp : Async_Op<"execute", [AttrSizedOperandSegments]> {
   let summary = "Asynchronous execute operation";
   let description = [{
     The `body` region attached to the `async.execute` operation semantically
@@ -40,24 +40,43 @@ def Async_ExecuteOp : Async_Op<"execute"> {
     state). All dependencies must be made explicit with async execute arguments
     (`async.token` or `async.value`).
 
+   `async.execute` operation takes `async.token` dependencies and `async.value`
+    operands separatly, and starts execution of the attached body region only
+    when all tokens and values become ready.
+
+    Example:
+
     ```mlir
-    %done, %values = async.execute {
-      %0 = "compute0"(...) : !some.type
-      async.yield %1 : f32
-    } : !async.token, !async.value<!some.type>
+    %dependency = ... : !async.token
+    %value = ... : !async.value<f32>
+
+    %token, %results =
+      async.execute [%dependency](%value as %unwrapped: !async.value<f32>)
+                 -> !async.value<!some.type>
+      {
+        %0 = "compute0"(%unwrapped): (f32) -> !some.type
+        async.yield %0 : !some.type
+      }
 
     %1 = "compute1"(...) : !some.type
     ```
+
+    In the example above asynchronous execution starts only after dependency
+    token and value argument become ready. Unwrapped value passed to the
+    attached body region as an %unwrapped value of f32 type.
   }];
 
-  // TODO: Take async.tokens/async.values as arguments.
-  let arguments = (ins );
-  let results = (outs Async_TokenType:$done,
-                      Variadic<Async_AnyValueType>:$values);
+  let arguments = (ins Variadic<Async_TokenType>:$dependencies,
+                       Variadic<Async_AnyValueOrTokenType>:$operands);
+
+  let results = (outs Async_TokenType:$token,
+                      Variadic<Async_AnyValueType>:$results);
   let regions = (region SizedRegion<1>:$body);
 
-  let printer = [{ return ::mlir::async::print(p, *this); }];
-  let parser = [{ return ::mlir::async::parse$cppClass(parser, result); }];
+  let printer = [{ return ::print(p, *this); }];
+  let parser = [{ return ::parse$cppClass(parser, result); }];
+
+  let verifier = [{ return ::verify(*this); }];
 }
 
 def Async_YieldOp :
@@ -72,7 +91,7 @@ def Async_YieldOp :
 
   let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
 
-  let verifier = [{ return ::mlir::async::verify(*this); }];
+  let verifier = [{ return ::verify(*this); }];
 }
 
 #endif // ASYNC_OPS
index 4d9ede1..eb5e65b 100644 (file)
@@ -8,19 +8,11 @@
 
 #include "mlir/Dialect/Async/IR/Async.h"
 
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/Traits.h"
-#include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/Transforms/InliningUtils.h"
-#include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/raw_ostream.h"
 
-namespace mlir {
-namespace async {
+using namespace mlir;
+using namespace mlir::async;
 
 void AsyncDialect::initialize() {
   addOperations<
@@ -69,6 +61,8 @@ void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const {
 /// ValueType
 //===----------------------------------------------------------------------===//
 
+namespace mlir {
+namespace async {
 namespace detail {
 
 // Storage for `async.value<T>` type, the only member is the wrapped type.
@@ -90,6 +84,8 @@ struct ValueTypeStorage : public TypeStorage {
 };
 
 } // namespace detail
+} // namespace async
+} // namespace mlir
 
 ValueType ValueType::get(Type valueType) {
   return Base::get(valueType.getContext(), valueType);
@@ -105,7 +101,7 @@ static LogicalResult verify(YieldOp op) {
   // Get the underlying value types from async values returned from the
   // parent `async.execute` operation.
   auto executeOp = op.getParentOfType<ExecuteOp>();
-  auto types = llvm::map_range(executeOp.values(), [](const OpResult &result) {
+  auto types = llvm::map_range(executeOp.results(), [](const OpResult &result) {
     return result.getType().cast<ValueType>().getValueType();
   });
 
@@ -120,49 +116,139 @@ static LogicalResult verify(YieldOp op) {
 /// ExecuteOp
 //===----------------------------------------------------------------------===//
 
+constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes";
+
 static void print(OpAsmPrinter &p, ExecuteOp op) {
-  p << "async.execute ";
-  p.printRegion(op.body());
-  p.printOptionalAttrDict(op.getAttrs());
-  p << " : ";
-  p.printType(op.done().getType());
-  if (!op.values().empty())
-    p << ", ";
-  llvm::interleaveComma(op.values(), p, [&](const OpResult &result) {
-    p.printType(result.getType());
-  });
+  p << op.getOperationName();
+
+  // [%tokens,...]
+  if (!op.dependencies().empty())
+    p << " [" << op.dependencies() << "]";
+
+  // (%value as %unwrapped: !async.value<!arg.type>, ...)
+  if (!op.operands().empty()) {
+    p << " (";
+    llvm::interleaveComma(op.operands(), p, [&, n = 0](Value operand) mutable {
+      p << operand << " as " << op.body().front().getArgument(n++) << ": "
+        << operand.getType();
+    });
+    p << ")";
+  }
+
+  // -> (!async.value<!return.type>, ...)
+  p.printOptionalArrowTypeList(op.getResultTypes().drop_front(1));
+  p.printOptionalAttrDictWithKeyword(op.getAttrs(), {kOperandSegmentSizesAttr});
+  p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
 }
 
 static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) {
   MLIRContext *ctx = result.getContext();
 
-  // Parse asynchronous region.
-  Region *body = result.addRegion();
-  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{},
-                         /*enableNameShadowing=*/false))
+  // Sizes of parsed variadic operands, will be updated below after parsing.
+  int32_t numDependencies = 0;
+  int32_t numOperands = 0;
+
+  auto tokenTy = TokenType::get(ctx);
+
+  // Parse dependency tokens.
+  if (succeeded(parser.parseOptionalLSquare())) {
+    SmallVector<OpAsmParser::OperandType, 4> tokenArgs;
+    if (parser.parseOperandList(tokenArgs) ||
+        parser.resolveOperands(tokenArgs, tokenTy, result.operands) ||
+        parser.parseRSquare())
+      return failure();
+
+    numDependencies = tokenArgs.size();
+  }
+
+  // Parse async value operands (%value as %unwrapped : !async.value<!type>).
+  SmallVector<OpAsmParser::OperandType, 4> valueArgs;
+  SmallVector<OpAsmParser::OperandType, 4> unwrappedArgs;
+  SmallVector<Type, 4> valueTypes;
+  SmallVector<Type, 4> unwrappedTypes;
+
+  if (succeeded(parser.parseOptionalLParen())) {
+    auto argsLoc = parser.getCurrentLocation();
+
+    // Parse a single instance of `%value as %unwrapped : !async.value<!type>`.
+    auto parseAsyncValueArg = [&]() -> ParseResult {
+      if (parser.parseOperand(valueArgs.emplace_back()) ||
+          parser.parseKeyword("as") ||
+          parser.parseOperand(unwrappedArgs.emplace_back()) ||
+          parser.parseColonType(valueTypes.emplace_back()))
+        return failure();
+
+      auto valueTy = valueTypes.back().dyn_cast<ValueType>();
+      unwrappedTypes.emplace_back(valueTy ? valueTy.getValueType() : Type());
+
+      return success();
+    };
+
+    // If the next token is `)` skip async value arguments parsing.
+    if (failed(parser.parseOptionalRParen())) {
+      do {
+        if (parseAsyncValueArg())
+          return failure();
+      } while (succeeded(parser.parseOptionalComma()));
+
+      if (parser.parseRParen() ||
+          parser.resolveOperands(valueArgs, valueTypes, argsLoc,
+                                 result.operands))
+        return failure();
+    }
+
+    numOperands = valueArgs.size();
+  }
+
+  // Add derived `operand_segment_sizes` attribute based on parsed operands.
+  auto operandSegmentSizes = DenseIntElementsAttr::get(
+      VectorType::get({2}, parser.getBuilder().getI32Type()),
+      {numDependencies, numOperands});
+  result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
+
+  // Parse the types of results returned from the async execute op.
+  SmallVector<Type, 4> resultTypes;
+  if (parser.parseOptionalArrowTypeList(resultTypes))
     return failure();
 
+  // Async execute first result is always a completion token.
+  parser.addTypeToList(tokenTy, result.types);
+  parser.addTypesToList(resultTypes, result.types);
+
   // Parse operation attributes.
   NamedAttrList attrs;
-  if (parser.parseOptionalAttrDict(attrs))
+  if (parser.parseOptionalAttrDictWithKeyword(attrs))
     return failure();
   result.addAttributes(attrs);
 
-  // Parse result types.
-  SmallVector<Type, 4> resultTypes;
-  if (parser.parseColonTypeList(resultTypes))
-    return failure();
-
-  // First result type must be an async token type.
-  if (resultTypes.empty() || resultTypes.front() != TokenType::get(ctx))
+  // Parse asynchronous region.
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, /*arguments=*/{unwrappedArgs},
+                         /*argTypes=*/{unwrappedTypes},
+                         /*enableNameShadowing=*/false))
     return failure();
-  parser.addTypesToList(resultTypes, result.types);
 
   return success();
 }
 
-} // namespace async
-} // namespace mlir
+static LogicalResult verify(ExecuteOp op) {
+  // Unwrap async.execute value operands types.
+  auto unwrappedTypes = llvm::map_range(op.operands(), [](Value operand) {
+    return operand.getType().cast<ValueType>().getValueType();
+  });
+
+  // Verify that unwrapped argument types matches the body region arguments.
+  if (llvm::size(unwrappedTypes) != llvm::size(op.body().getArgumentTypes()))
+    return op.emitOpError("the number of async body region arguments does not "
+                          "match the number of execute operation arguments");
+
+  if (!std::equal(unwrappedTypes.begin(), unwrappedTypes.end(),
+                  op.body().getArgumentTypes().begin()))
+    return op.emitOpError("async body region argument types do not match the "
+                          "execute operation arguments types");
+
+  return success();
+}
 
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
index d23bc00..371cea7 100644 (file)
@@ -1,7 +1,7 @@
-// RUN: mlir-opt  %s | FileCheck %s
+// RUN: mlir-opt %s | FileCheck %s
 
 // CHECK-LABEL: @identity_token
-func @identity_token(%arg0 : !async.token) -> !async.token {
+func @identity_token(%arg0: !async.token) -> !async.token {
   // CHECK: return %arg0 : !async.token
   return %arg0 : !async.token
 }
@@ -14,33 +14,95 @@ func @identity_value(%arg0 : !async.value<f32>) -> !async.value<f32> {
 
 // CHECK-LABEL: @empty_async_execute
 func @empty_async_execute() -> !async.token {
-  %done = async.execute {
+  // CHECK: async.execute
+  %token = async.execute {
     async.yield
-  } : !async.token
+  }
 
-  // CHECK: return %done : !async.token
-  return %done : !async.token
+  // CHECK: return %token : !async.token
+  return %token : !async.token
 }
 
 // CHECK-LABEL: @return_async_value
 func @return_async_value() -> !async.value<f32> {
-  %done, %values = async.execute {
+  // CHECK: async.execute -> !async.value<f32>
+  %token, %results = async.execute -> !async.value<f32> {
     %cst = constant 1.000000e+00 : f32
     async.yield %cst : f32
-  } : !async.token, !async.value<f32>
+  }
 
-  // CHECK: return %values : !async.value<f32>
-  return %values : !async.value<f32>
+  // CHECK: return %results : !async.value<f32>
+  return %results : !async.value<f32>
+}
+
+// CHECK-LABEL: @return_captured_value
+func @return_captured_value() -> !async.token {
+  %cst = constant 1.000000e+00 : f32
+  // CHECK: async.execute -> !async.value<f32>
+  %token, %results = async.execute -> !async.value<f32> {
+    async.yield %cst : f32
+  }
+
+  // CHECK: return %token : !async.token
+  return %token : !async.token
 }
 
 // CHECK-LABEL: @return_async_values
 func @return_async_values() -> (!async.value<f32>, !async.value<f32>) {
-  %done, %values:2 = async.execute {
+  %token, %results:2 = async.execute -> (!async.value<f32>, !async.value<f32>) {
     %cst1 = constant 1.000000e+00 : f32
     %cst2 = constant 2.000000e+00 : f32
     async.yield %cst1, %cst2 : f32, f32
-  } : !async.token, !async.value<f32>, !async.value<f32>
+  }
+
+  // CHECK: return %results#0, %results#1 : !async.value<f32>, !async.value<f32>
+  return %results#0, %results#1 : !async.value<f32>, !async.value<f32>
+}
+
+// CHECK-LABEL: @async_token_dependencies
+func @async_token_dependencies(%arg0: !async.token) -> !async.token {
+  // CHECK: async.execute [%arg0]
+  %token = async.execute [%arg0] {
+    async.yield
+  }
+
+  // CHECK: return %token : !async.token
+  return %token : !async.token
+}
+
+// CHECK-LABEL: @async_value_operands
+func @async_value_operands(%arg0: !async.value<f32>) -> !async.token {
+  // CHECK: async.execute (%arg0 as %arg1: !async.value<f32>) -> !async.value<f32>
+  %token, %results = async.execute (%arg0 as %arg1: !async.value<f32>) -> !async.value<f32> {
+    async.yield %arg1 : f32
+  }
+
+  // CHECK: return %token : !async.token
+  return %token : !async.token
+}
+
+// CHECK-LABEL: @async_token_and_value_operands
+func @async_token_and_value_operands(%arg0: !async.token, %arg1: !async.value<f32>) -> !async.token {
+  // CHECK: async.execute [%arg0] (%arg1 as %arg2: !async.value<f32>) -> !async.value<f32>
+  %token, %results = async.execute [%arg0] (%arg1 as %arg2: !async.value<f32>) -> !async.value<f32> {
+    async.yield %arg2 : f32
+  }
+
+  // CHECK: return %token : !async.token
+  return %token : !async.token
+}
 
-  // CHECK: return %values#0, %values#1 : !async.value<f32>, !async.value<f32>
-  return %values#0, %values#1 : !async.value<f32>, !async.value<f32>
+// CHECK-LABEL: @empty_tokens_or_values_operands
+func @empty_tokens_or_values_operands() {
+  // CHECK: async.execute {
+  %token0 = async.execute [] () -> () { async.yield }
+  // CHECK: async.execute {
+  %token1 = async.execute () -> () { async.yield }
+  // CHECK: async.execute {
+  %token2 = async.execute -> () { async.yield }
+  // CHECK: async.execute {
+  %token3 = async.execute () { async.yield }
+  // CHECK: async.execute {
+  %token4 = async.execute [] { async.yield }
+  return
 }