[MLIR][SPIRV] Extend _reference_of to support SpecConstantCompositeOp.
authorergawy <kareem.ergawy@gmail.com>
Mon, 5 Oct 2020 20:39:39 +0000 (16:39 -0400)
committerLei Zhang <antiagainst@google.com>
Mon, 5 Oct 2020 21:04:55 +0000 (17:04 -0400)
Adds support for SPIR-V composite speciailization constants to spv._reference_of.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D88732

mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
mlir/test/Dialect/SPIRV/Serialization/spec-constant.mlir
mlir/test/Dialect/SPIRV/structure-ops.mlir

index 0e866f0..c64606b 100644 (file)
@@ -472,7 +472,7 @@ def SPV_ReferenceOfOp : SPV_Op<"_reference_of", [NoSideEffect]> {
   let summary = "Reference a specialization constant.";
 
   let description = [{
-    Specialization constant in module scope are defined using symbol names.
+    Specialization constants in module scope are defined using symbol names.
     This op generates an SSA value that can be used to refer to the symbol
     within function scope for use in ops that expect an SSA value.
     This operation has no corresponding SPIR-V instruction; it's merely used
index 363785e..ad25ecb 100644 (file)
@@ -2568,17 +2568,27 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) {
 //===----------------------------------------------------------------------===//
 
 static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) {
-  auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(
-      SymbolTable::lookupNearestSymbolFrom(referenceOfOp.getParentOp(),
-                                           referenceOfOp.spec_const()));
-  if (!specConstOp) {
-    return referenceOfOp.emitOpError("expected spv.specConstant symbol");
-  }
-  if (referenceOfOp.reference().getType() !=
-      specConstOp.default_value().getType()) {
+  auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
+      referenceOfOp.getParentOp(), referenceOfOp.spec_const());
+  Type constType;
+
+  auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
+  if (specConstOp)
+    constType = specConstOp.default_value().getType();
+
+  auto specConstCompositeOp =
+      dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
+  if (specConstCompositeOp)
+    constType = specConstCompositeOp.type();
+
+  if (!specConstOp && !specConstCompositeOp)
+    return referenceOfOp.emitOpError(
+        "expected spv.specConstant or spv.SpecConstantComposite symbol");
+
+  if (referenceOfOp.reference().getType() != constType)
     return referenceOfOp.emitOpError("result type mismatch with the referenced "
                                      "specialization constant's type");
-  }
+
   return success();
 }
 
index 153540d..33966f8 100644 (file)
@@ -187,6 +187,11 @@ private:
     return specConstMap.lookup(id);
   }
 
+  /// Gets the composite specialization constant with the given result <id>.
+  spirv::SpecConstantCompositeOp getSpecConstantComposite(uint32_t id) {
+    return specConstCompositeMap.lookup(id);
+  }
+
   /// Creates a spirv::SpecConstantOp.
   spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID,
                                            Attribute defaultValue);
@@ -461,9 +466,12 @@ private:
   /// (and type) here. Later when it's used, we materialize the constant.
   DenseMap<uint32_t, std::pair<Attribute, Type>> constantMap;
 
-  // Result <id> to variable mapping.
+  // Result <id> to spec constant mapping.
   DenseMap<uint32_t, spirv::SpecConstantOp> specConstMap;
 
+  // Result <id> to composite spec constant mapping.
+  DenseMap<uint32_t, spirv::SpecConstantCompositeOp> specConstCompositeMap;
+
   // Result <id> to variable mapping.
   DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
 
@@ -1565,7 +1573,8 @@ Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
            << operands[0];
   }
 
-  auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(operands[1]));
+  auto resultID = operands[1];
+  auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
 
   SmallVector<Attribute, 4> elements;
   elements.reserve(operands.size() - 2);
@@ -1574,9 +1583,10 @@ Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
     elements.push_back(opBuilder.getSymbolRefAttr(elementInfo));
   }
 
-  opBuilder.create<spirv::SpecConstantCompositeOp>(
+  auto op = opBuilder.create<spirv::SpecConstantCompositeOp>(
       unknownLoc, TypeAttr::get(resultType), symName,
       opBuilder.getArrayAttr(elements));
+  specConstCompositeMap[resultID] = op;
 
   return success();
 }
@@ -2208,6 +2218,12 @@ Value Deserializer::getValue(uint32_t id) {
         opBuilder.getSymbolRefAttr(constOp.getOperation()));
     return referenceOfOp.reference();
   }
+  if (auto constCompositeOp = getSpecConstantComposite(id)) {
+    auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
+        unknownLoc, constCompositeOp.type(),
+        opBuilder.getSymbolRefAttr(constCompositeOp.getOperation()));
+    return referenceOfOp.reference();
+  }
   if (auto undef = getUndefType(id)) {
     return opBuilder.create<spirv::UndefOp>(unknownLoc, undef);
   }
index 0df9301..2cbfcc6 100644 (file)
@@ -12,6 +12,9 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
   // CHECK: spv.specConstant @sc_float spec_id(5) = 1.000000e+00 : f32
   spv.specConstant @sc_float spec_id(5) = 1. : f32
 
+  // CHECK: spv.specConstantComposite @scc (@sc_int, @sc_int) : !spv.array<2 x i32>
+  spv.specConstantComposite @scc (@sc_int, @sc_int) : !spv.array<2 x i32>
+
   // CHECK-LABEL: @use
   spv.func @use() -> (i32) "None" {
     // We materialize a `spv._reference_of` op at every use of a
@@ -24,6 +27,43 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
     %1 = spv.IAdd %0, %0 : i32
     spv.ReturnValue %1 : i32
   }
+
+  // CHECK-LABEL: @use
+  spv.func @use_composite() -> (i32) "None" {
+    // We materialize a `spv._reference_of` op at every use of a
+    // specialization constant in the deserializer. So two ops here.
+    // CHECK: %[[USE1:.*]] = spv._reference_of @scc : !spv.array<2 x i32>
+    // CHECK: %[[ITM0:.*]] = spv.CompositeExtract %[[USE1]][0 : i32] : !spv.array<2 x i32>
+    // CHECK: %[[USE2:.*]] = spv._reference_of @scc : !spv.array<2 x i32>
+    // CHECK: %[[ITM1:.*]] = spv.CompositeExtract %[[USE2]][1 : i32] : !spv.array<2 x i32>
+    // CHECK: spv.IAdd %[[ITM0]], %[[ITM1]]
+
+    %0 = spv._reference_of @scc : !spv.array<2 x i32>
+    %1 = spv.CompositeExtract %0[0 : i32] : !spv.array<2 x i32>
+    %2 = spv.CompositeExtract %0[1 : i32] : !spv.array<2 x i32>
+    %3 = spv.IAdd %1, %2 : i32
+    spv.ReturnValue %3 : i32
+  }
+}
+
+// -----
+
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
+
+  spv.specConstant @sc_f32_1 = 1.5 : f32
+  spv.specConstant @sc_f32_2 = 2.5 : f32
+  spv.specConstant @sc_f32_3 = 3.5 : f32
+
+  spv.specConstant @sc_i32_1 = 1   : i32
+
+  // CHECK: spv.specConstantComposite @scc_array (@sc_f32_1, @sc_f32_2, @sc_f32_3) : !spv.array<3 x f32>
+  spv.specConstantComposite @scc_array (@sc_f32_1, @sc_f32_2, @sc_f32_3) : !spv.array<3 x f32>
+
+  // CHECK: spv.specConstantComposite @scc_struct (@sc_i32_1, @sc_f32_2, @sc_f32_3) : !spv.struct<i32, f32, f32>
+  spv.specConstantComposite @scc_struct (@sc_i32_1, @sc_f32_2, @sc_f32_3) : !spv.struct<i32, f32, f32>
+
+  // CHECK: spv.specConstantComposite @scc_vector (@sc_f32_1, @sc_f32_2, @sc_f32_3) : vector<3xf32>
+  spv.specConstantComposite @scc_vector (@sc_f32_1, @sc_f32_2, @sc_f32_3) : vector<3 x f32>
 }
 
 // -----
index 765eba9..7bb98b9 100644 (file)
@@ -496,6 +496,8 @@ spv.module Logical GLSL450 {
   spv.specConstant @sc2 = 42 : i64
   spv.specConstant @sc3 = 1.5 : f32
 
+  spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.struct<i1, i64, f32>
+
   // CHECK-LABEL: @reference
   spv.func @reference() -> i1 "None" {
     // CHECK: spv._reference_of @sc1 : i1
@@ -503,6 +505,14 @@ spv.module Logical GLSL450 {
     spv.ReturnValue %0 : i1
   }
 
+  // CHECK-LABEL: @reference_composite
+  spv.func @reference_composite() -> i1 "None" {
+    // CHECK: spv._reference_of @scc : !spv.struct<i1, i64, f32>
+    %0 = spv._reference_of @scc : !spv.struct<i1, i64, f32>
+    %1 = spv.CompositeExtract %0[0 : i32] : !spv.struct<i1, i64, f32>
+    spv.ReturnValue %1 : i1
+  }
+
   // CHECK-LABEL: @initialize
   spv.func @initialize() -> i64 "None" {
     // CHECK: spv._reference_of @sc2 : i64
@@ -534,9 +544,21 @@ func @reference_of() {
 
 // -----
 
+spv.specConstant @sc = 5 : i32
+spv.specConstantComposite @scc (@sc) : !spv.array<1 x i32>
+
+func @reference_of_composite() {
+  // CHECK: spv._reference_of @scc : !spv.array<1 x i32>
+  %0 = spv._reference_of @scc : !spv.array<1 x i32>
+  %1 = spv.CompositeExtract %0[0 : i32] : !spv.array<1 x i32>
+  return
+}
+
+// -----
+
 spv.module Logical GLSL450 {
   spv.func @foo() -> () "None" {
-    // expected-error @+1 {{expected spv.specConstant symbol}}
+    // expected-error @+1 {{expected spv.specConstant or spv.SpecConstantComposite symbol}}
     %0 = spv._reference_of @sc : i32
     spv.Return
   }
@@ -555,6 +577,18 @@ spv.module Logical GLSL450 {
 
 // -----
 
+spv.module Logical GLSL450 {
+  spv.specConstant @sc = 42 : i32
+  spv.specConstantComposite @scc (@sc) : !spv.array<1 x i32>
+  spv.func @foo() -> () "None" {
+    // expected-error @+1 {{result type mismatch with the referenced specialization constant's type}}
+    %0 = spv._reference_of @scc : f32
+    spv.Return
+  }
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spv.specConstant
 //===----------------------------------------------------------------------===//