[mlir][spirv] Add spv.GroupNonUniformElect and spv.GroupNonUniformIAdd
author Lei Zhang <antiagainst@google.com>
Sun, 26 Jan 2020 15:19:24 +0000 (10:19 -0500)
committer Lei Zhang <antiagainst@google.com>
Sun, 26 Jan 2020 15:20:40 +0000 (10:20 -0500)
Differential Revision: https://reviews.llvm.org/D73349

mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/Serialization/non-uniform-ops.mlir
mlir/test/Dialect/SPIRV/non-uniform-ops.mlir

index 6802471..3026350 100644 mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -2353,6 +2353,53 @@ def SPV_FunctionControlAttr :
       SPV_FC_None, SPV_FC_Inline, SPV_FC_DontInline, SPV_FC_Pure, SPV_FC_Const
     ]>;
 
+def SPV_GO_Reduce                     : I32EnumAttrCase<"Reduce", 0> {
+  list<Availability> availability = [
+    Capability<[SPV_C_GroupNonUniformArithmetic, SPV_C_GroupNonUniformBallot, SPV_C_Kernel]>
+  ];
+}
+def SPV_GO_InclusiveScan              : I32EnumAttrCase<"InclusiveScan", 1> {
+  list<Availability> availability = [
+    Capability<[SPV_C_GroupNonUniformArithmetic, SPV_C_GroupNonUniformBallot, SPV_C_Kernel]>
+  ];
+}
+def SPV_GO_ExclusiveScan              : I32EnumAttrCase<"ExclusiveScan", 2> {
+  list<Availability> availability = [
+    Capability<[SPV_C_GroupNonUniformArithmetic, SPV_C_GroupNonUniformBallot, SPV_C_Kernel]>
+  ];
+}
+def SPV_GO_ClusteredReduce            : I32EnumAttrCase<"ClusteredReduce", 3> {
+  list<Availability> availability = [
+    MinVersion<SPV_V_1_3>,
+    Capability<[SPV_C_GroupNonUniformClustered]>
+  ];
+}
+def SPV_GO_PartitionedReduceNV        : I32EnumAttrCase<"PartitionedReduceNV", 6> {
+  list<Availability> availability = [
+    Extension<[SPV_NV_shader_subgroup_partitioned]>,
+    Capability<[SPV_C_GroupNonUniformPartitionedNV]>
+  ];
+}
+def SPV_GO_PartitionedInclusiveScanNV : I32EnumAttrCase<"PartitionedInclusiveScanNV", 7> {
+  list<Availability> availability = [
+    Extension<[SPV_NV_shader_subgroup_partitioned]>,
+    Capability<[SPV_C_GroupNonUniformPartitionedNV]>
+  ];
+}
+def SPV_GO_PartitionedExclusiveScanNV : I32EnumAttrCase<"PartitionedExclusiveScanNV", 8> {
+  list<Availability> availability = [
+    Extension<[SPV_NV_shader_subgroup_partitioned]>,
+    Capability<[SPV_C_GroupNonUniformPartitionedNV]>
+  ];
+}
+
+def SPV_GroupOperationAttr :
+    SPV_I32EnumAttr<"GroupOperation", "valid SPIR-V GroupOperation", [
+      SPV_GO_Reduce, SPV_GO_InclusiveScan, SPV_GO_ExclusiveScan,
+      SPV_GO_ClusteredReduce, SPV_GO_PartitionedReduceNV,
+      SPV_GO_PartitionedInclusiveScanNV, SPV_GO_PartitionedExclusiveScanNV
+    ]>;
+
 def SPV_IF_Unknown      : I32EnumAttrCase<"Unknown", 0>;
 def SPV_IF_Rgba32f      : I32EnumAttrCase<"Rgba32f", 1> {
   list<Availability> availability = [
@@ -3108,7 +3155,9 @@ def SPV_OC_OpReturn                    : I32EnumAttrCase<"OpReturn", 253>;
 def SPV_OC_OpReturnValue               : I32EnumAttrCase<"OpReturnValue", 254>;
 def SPV_OC_OpUnreachable               : I32EnumAttrCase<"OpUnreachable", 255>;
 def SPV_OC_OpModuleProcessed           : I32EnumAttrCase<"OpModuleProcessed", 330>;
+def SPV_OC_OpGroupNonUniformElect      : I32EnumAttrCase<"OpGroupNonUniformElect", 333>;
 def SPV_OC_OpGroupNonUniformBallot     : I32EnumAttrCase<"OpGroupNonUniformBallot", 339>;
+def SPV_OC_OpGroupNonUniformIAdd       : I32EnumAttrCase<"OpGroupNonUniformIAdd", 349>;
 def SPV_OC_OpSubgroupBallotKHR         : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
 
 def SPV_OpcodeAttr :
@@ -3155,7 +3204,8 @@ def SPV_OpcodeAttr :
       SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel,
       SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
       SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpModuleProcessed,
-      SPV_OC_OpGroupNonUniformBallot, SPV_OC_OpSubgroupBallotKHR
+      SPV_OC_OpGroupNonUniformElect, SPV_OC_OpGroupNonUniformBallot,
+      SPV_OC_OpGroupNonUniformIAdd, SPV_OC_OpSubgroupBallotKHR
     ]>;
 
 // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
index 1d63ed8..cc288f9 100644 mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td
@@ -72,5 +72,118 @@ def SPV_GroupNonUniformBallotOp : SPV_Op<"GroupNonUniformBallot", []> {
 
 // -----
 
-#endif // SPIRV_NON_UNIFORM_OPS
+def SPV_GroupNonUniformElectOp : SPV_Op<"GroupNonUniformElect", []> {
+  let summary = [{
+    Result is true only in the active invocation with the lowest id in the
+    group, otherwise result is false.
+  }];
+
+  let description = [{
+    Result Type must be a Boolean type.
+
+    Execution must be Workgroup or Subgroup Scope.
+
+    ### Custom assembly form
+
+    ```
+    scope ::= `"Workgroup"` | `"Subgroup"`
+    non-uniform-elect-op ::= ssa-id `=` `spv.GroupNonUniformElect` scope
+                             `:` `i1`
+    ```
+
+    For example:
+
+    ```
+    %0 = spv.GroupNonUniformElect "Workgroup" : i1
+    ```
+  }];
 
+  let availability = [
+    MinVersion<SPV_V_1_3>,
+    MaxVersion<SPV_V_1_5>,
+    Extension<[]>,
+    Capability<[SPV_C_GroupNonUniform]>
+  ];
+
+  let arguments = (ins
+    SPV_ScopeAttr:$execution_scope
+  );
+
+  let results = (outs
+    SPV_Bool:$result
+  );
+
+  let builders = [
+    OpBuilder<[{Builder *builder, OperationState &state, spirv::Scope}]>
+  ];
+}
+
+// -----
+
+def SPV_GroupNonUniformIAddOp : SPV_Op<"GroupNonUniformIAdd", []> {
+  let summary = [{
+    An integer add group operation of all Value operands contributed by
+    active invocations in the group.
+  }];
+
+  let description = [{
+    Result Type must be a scalar or vector of integer type.
+
+    Execution must be Workgroup or Subgroup Scope.
+
+    The identity I for Operation is 0. If Operation is ClusteredReduce,
+    ClusterSize must be specified.
+
+    The type of Value must be the same as Result Type.
+
+    ClusterSize is the size of cluster to use. ClusterSize must be a scalar
+    of integer type, whose Signedness operand is 0. ClusterSize must come
+    from a constant instruction. ClusterSize must be at least 1, and must be
+    a power of 2. If ClusterSize is greater than the declared SubGroupSize,
+    executing this instruction results in undefined behavior.
+
+    ### Custom assembly form
+
+    ```
+    scope ::= `"Workgroup"` | `"Subgroup"`
+    operation ::= `"Reduce"` | `"InclusiveScan"` | `"ExclusiveScan"` | ...
+    integer-scalar-vector-type ::= integer-type |
+                                   `vector<` integer-literal `x` integer-type `>`
+    non-uniform-iadd-op ::= ssa-id `=` `spv.GroupNonUniformIAdd` scope operation
+                            ssa-use ( `cluster_size` `(` ssa-use `)` )?
+                            `:` integer-scalar-vector-type
+    ```
+
+    For example:
+
+    ```
+    %four = spv.constant 4 : i32
+    %scalar = ... : i32
+    %vector = ... : vector<4xi32>
+    %0 = spv.GroupNonUniformIAdd "Workgroup" "Reduce" %scalar : i32
+    %1 = spv.GroupNonUniformIAdd "Subgroup" "ClusteredReduce" %vector cluster_size(%four) : vector<4xi32>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPV_V_1_3>,
+    MaxVersion<SPV_V_1_5>,
+    Extension<[]>,
+    Capability<[SPV_C_GroupNonUniformArithmetic, SPV_C_GroupNonUniformClustered, SPV_C_GroupNonUniformPartitionedNV]>
+  ];
+
+  let arguments = (ins
+    SPV_ScopeAttr:$execution_scope,
+    SPV_GroupOperationAttr:$group_operation,
+    SPV_ScalarOrVectorOf<SPV_Integer>:$value,
+    SPV_Optional<SPV_Integer>:$cluster_size
+  );
+
+  let results = (outs
+    SPV_ScalarOrVectorOf<SPV_Integer>:$result
+  );
+}
+
+// -----
+
+#endif // SPIRV_NON_UNIFORM_OPS
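
The grammar above also admits the scan group operations ("InclusiveScan", "ExclusiveScan"), which the examples and tests in this commit do not exercise. A minimal sketch of what those variants would look like in the custom assembly form, assuming the same scope and operand rules as the "Reduce" examples (%scalar and %vector are placeholder SSA values):

```
// Scan variants of the integer add group operation; no cluster_size
// operand is involved for these group operations.
%2 = spv.GroupNonUniformIAdd "Subgroup" "InclusiveScan" %scalar : i32
%3 = spv.GroupNonUniformIAdd "Subgroup" "ExclusiveScan" %vector : vector<4xi32>
```
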
index 42e0df6..5e2fede 100644 mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -32,10 +32,12 @@ using namespace mlir;
 static constexpr const char kAlignmentAttrName[] = "alignment";
 static constexpr const char kBranchWeightAttrName[] = "branch_weights";
 static constexpr const char kCallee[] = "callee";
+static constexpr const char kClusterSize[] = "cluster_size";
 static constexpr const char kDefaultValueAttrName[] = "default_value";
 static constexpr const char kExecutionScopeAttrName[] = "execution_scope";
 static constexpr const char kEqualSemanticsAttrName[] = "equal_semantics";
 static constexpr const char kFnNameAttrName[] = "fn";
+static constexpr const char kGroupOperationAttrName[] = "group_operation";
 static constexpr const char kIndicesAttrName[] = "indices";
 static constexpr const char kInitializerAttrName[] = "initializer";
 static constexpr const char kInterfaceAttrName[] = "interface";
@@ -53,9 +55,8 @@ static constexpr const char kVariableAttrName[] = "variable";
 // Common utility functions
 //===----------------------------------------------------------------------===//
 
-static LogicalResult extractValueFromConstOp(Operation *op,
-                                             int32_t &indexValue) {
-  auto constOp = dyn_cast<spirv::ConstantOp>(op);
+static LogicalResult extractValueFromConstOp(Operation *op, int32_t &value) {
+  auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
   if (!constOp) {
     return failure();
   }
@@ -64,7 +65,7 @@ static LogicalResult extractValueFromConstOp(Operation *op,
   if (!integerValueAttr) {
     return failure();
   }
-  indexValue = integerValueAttr.getInt();
+  value = integerValueAttr.getInt();
   return success();
 }
 
@@ -1889,6 +1890,122 @@ static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) {
 }
 
 //===----------------------------------------------------------------------===//
+// spv.GroupNonUniformElectOp
+//===----------------------------------------------------------------------===//
+
+void spirv::GroupNonUniformElectOp::build(Builder *builder,
+                                          OperationState &state,
+                                          spirv::Scope scope) {
+  build(builder, state, builder->getI1Type(), scope);
+}
+
+static ParseResult parseGroupNonUniformElectOp(OpAsmParser &parser,
+                                               OperationState &state) {
+  spirv::Scope executionScope;
+  Type resultType;
+  if (parseEnumAttribute(executionScope, parser, state,
+                         kExecutionScopeAttrName) ||
+      parser.parseColonType(resultType))
+    return failure();
+
+  return parser.addTypeToList(resultType, state.types);
+}
+
+static void print(spirv::GroupNonUniformElectOp groupOp,
+                  OpAsmPrinter &printer) {
+  printer << spirv::GroupNonUniformElectOp::getOperationName() << " \""
+          << stringifyScope(groupOp.execution_scope())
+          << "\" : " << groupOp.getType();
+}
+
+static LogicalResult verify(spirv::GroupNonUniformElectOp groupOp) {
+  spirv::Scope scope = groupOp.execution_scope();
+  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
+    return groupOp.emitOpError(
+        "execution scope must be 'Workgroup' or 'Subgroup'");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spv.GroupNonUniformIAddOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseGroupNonUniformIAddOp(OpAsmParser &parser,
+                                              OperationState &state) {
+  spirv::Scope executionScope;
+  spirv::GroupOperation groupOperation;
+  OpAsmParser::OperandType valueInfo;
+  if (parseEnumAttribute(executionScope, parser, state,
+                         kExecutionScopeAttrName) ||
+      parseEnumAttribute(groupOperation, parser, state,
+                         kGroupOperationAttrName) ||
+      parser.parseOperand(valueInfo))
+    return failure();
+
+  Optional<OpAsmParser::OperandType> clusterSizeInfo;
+  if (succeeded(parser.parseOptionalKeyword(kClusterSize))) {
+    clusterSizeInfo = OpAsmParser::OperandType();
+    if (parser.parseLParen() || parser.parseOperand(*clusterSizeInfo) ||
+        parser.parseRParen())
+      return failure();
+  }
+
+  Type resultType;
+  if (parser.parseColonType(resultType))
+    return failure();
+
+  if (parser.resolveOperand(valueInfo, resultType, state.operands))
+    return failure();
+
+  if (clusterSizeInfo.hasValue()) {
+    Type i32Type = parser.getBuilder().getIntegerType(32);
+    if (parser.resolveOperand(*clusterSizeInfo, i32Type, state.operands))
+      return failure();
+  }
+
+  return parser.addTypeToList(resultType, state.types);
+}
+
+static void print(spirv::GroupNonUniformIAddOp groupOp, OpAsmPrinter &printer) {
+  printer << spirv::GroupNonUniformIAddOp::getOperationName() << " \""
+          << stringifyScope(groupOp.execution_scope()) << "\" \""
+          << stringifyGroupOperation(groupOp.group_operation()) << "\" "
+          << groupOp.value();
+  if (!groupOp.cluster_size().empty())
+    printer << " " << kClusterSize << '(' << groupOp.cluster_size() << ')';
+  printer << " : " << groupOp.getType();
+}
+
+static LogicalResult verify(spirv::GroupNonUniformIAddOp groupOp) {
+  spirv::Scope scope = groupOp.execution_scope();
+  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
+    return groupOp.emitOpError(
+        "execution scope must be 'Workgroup' or 'Subgroup'");
+
+  spirv::GroupOperation operation = groupOp.group_operation();
+  if (operation == spirv::GroupOperation::ClusteredReduce &&
+      groupOp.cluster_size().empty())
+    return groupOp.emitOpError("cluster size operand must be provided for "
+                               "'ClusteredReduce' group operation");
+
+  if (!groupOp.cluster_size().empty()) {
+    Operation *sizeOp = (*groupOp.cluster_size().begin()).getDefiningOp();
+    int32_t clusterSize = 0;
+
+    // TODO(antiagainst): support specialization constant here.
+    if (failed(extractValueFromConstOp(sizeOp, clusterSize)))
+      return groupOp.emitOpError(
+          "cluster size operand must come from a constant op");
+
+    if (!llvm::isPowerOf2_32(clusterSize))
+      return groupOp.emitOpError("cluster size operand must be a power of two");
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // spv.IAdd
 //===----------------------------------------------------------------------===//
 
index 282811e..ed414d8 100644 mlir/test/Dialect/SPIRV/Serialization/non-uniform-ops.mlir
@@ -7,4 +7,26 @@ spv.module "Logical" "GLSL450" {
   %0 = spv.GroupNonUniformBallot "Workgroup" %predicate : vector<4xi32>
     spv.ReturnValue %0: vector<4xi32>
   }
+
+  // CHECK-LABEL: @group_non_uniform_elect
+  func @group_non_uniform_elect() -> i1 {
+    // CHECK: %{{.+}} = spv.GroupNonUniformElect "Workgroup" : i1
+    %0 = spv.GroupNonUniformElect "Workgroup" : i1
+    spv.ReturnValue %0: i1
+  }
+
+  // CHECK-LABEL: @group_non_uniform_iadd_reduce
+  func @group_non_uniform_iadd_reduce(%val: i32) -> i32 {
+    // CHECK: %{{.+}} = spv.GroupNonUniformIAdd "Workgroup" "Reduce" %{{.+}} : i32
+    %0 = spv.GroupNonUniformIAdd "Workgroup" "Reduce" %val : i32
+    spv.ReturnValue %0: i32
+  }
+
+  // CHECK-LABEL: @group_non_uniform_iadd_clustered_reduce
+  func @group_non_uniform_iadd_clustered_reduce(%val: vector<2xi32>) -> vector<2xi32> {
+    %four = spv.constant 4 : i32
+    // CHECK: %{{.+}} = spv.GroupNonUniformIAdd "Workgroup" "ClusteredReduce" %{{.+}} cluster_size(%{{.+}}) : vector<2xi32>
+    %0 = spv.GroupNonUniformIAdd "Workgroup" "ClusteredReduce" %val cluster_size(%four) : vector<2xi32>
+    spv.ReturnValue %0: vector<2xi32>
+  }
 }
index 483a731..8cbd2d2 100644 mlir/test/Dialect/SPIRV/non-uniform-ops.mlir
@@ -4,7 +4,7 @@
 // spv.GroupNonUniformBallot
 //===----------------------------------------------------------------------===//
 
-func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
+func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> {
   // CHECK: %{{.*}} = spv.GroupNonUniformBallot "Workgroup" %{{.*}}: vector<4xi32>
   %0 = spv.GroupNonUniformBallot "Workgroup" %predicate : vector<4xi32>
   return %0: vector<4xi32>
@@ -12,8 +12,83 @@ func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
 
 // -----
 
-func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
+func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> {
   // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
   %0 = spv.GroupNonUniformBallot "Device" %predicate : vector<4xi32>
   return %0: vector<4xi32>
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.GroupNonUniformElect
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @group_non_uniform_elect
+func @group_non_uniform_elect() -> i1 {
+  // CHECK: %{{.+}} = spv.GroupNonUniformElect "Workgroup" : i1
+  %0 = spv.GroupNonUniformElect "Workgroup" : i1
+  return %0: i1
+}
+
+// -----
+
+func @group_non_uniform_elect() -> i1 {
+  // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+  %0 = spv.GroupNonUniformElect "CrossDevice" : i1
+  return %0: i1
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.GroupNonUniformIAdd
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @group_non_uniform_iadd_reduce
+func @group_non_uniform_iadd_reduce(%val: i32) -> i32 {
+  // CHECK: %{{.+}} = spv.GroupNonUniformIAdd "Workgroup" "Reduce" %{{.+}} : i32
+  %0 = spv.GroupNonUniformIAdd "Workgroup" "Reduce" %val : i32
+  return %0: i32
+}
+
+// CHECK-LABEL: @group_non_uniform_iadd_clustered_reduce
+func @group_non_uniform_iadd_clustered_reduce(%val: vector<2xi32>) -> vector<2xi32> {
+  %four = spv.constant 4 : i32
+  // CHECK: %{{.+}} = spv.GroupNonUniformIAdd "Workgroup" "ClusteredReduce" %{{.+}} cluster_size(%{{.+}}) : vector<2xi32>
+  %0 = spv.GroupNonUniformIAdd "Workgroup" "ClusteredReduce" %val cluster_size(%four) : vector<2xi32>
+  return %0: vector<2xi32>
+}
+
+// -----
+
+func @group_non_uniform_iadd_reduce(%val: i32) -> i32 {
+  // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+  %0 = spv.GroupNonUniformIAdd "Device" "Reduce" %val : i32
+  return %0: i32
+}
+
+// -----
+
+func @group_non_uniform_iadd_clustered_reduce(%val: vector<2xi32>) -> vector<2xi32> {
+  // expected-error @+1 {{cluster size operand must be provided for 'ClusteredReduce' group operation}}
+  %0 = spv.GroupNonUniformIAdd "Workgroup" "ClusteredReduce" %val : vector<2xi32>
+  return %0: vector<2xi32>
+}
+
+// -----
+
+func @group_non_uniform_iadd_clustered_reduce(%val: vector<2xi32>, %size: i32) -> vector<2xi32> {
+  // expected-error @+1 {{cluster size operand must come from a constant op}}
+  %0 = spv.GroupNonUniformIAdd "Workgroup" "ClusteredReduce" %val cluster_size(%size) : vector<2xi32>
+  return %0: vector<2xi32>
+}
+
+// -----
+
+func @group_non_uniform_iadd_clustered_reduce(%val: vector<2xi32>) -> vector<2xi32> {
+  %five = spv.constant 5 : i32
+  // expected-error @+1 {{cluster size operand must be a power of two}}
+  %0 = spv.GroupNonUniformIAdd "Workgroup" "ClusteredReduce" %val cluster_size(%five) : vector<2xi32>
+  return %0: vector<2xi32>
+}