Update UndefOp (de)serialization to generate OpUndef at module level.
authorMahesh Ravishankar <ravishankarm@google.com>
Mon, 7 Oct 2019 19:56:08 +0000 (12:56 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 7 Oct 2019 19:56:38 +0000 (12:56 -0700)
The SPIR-V spec recommends all OpUndef instructions be generated at
module level. For the SPIR-V dialect its better for UndefOp to produce
an SSA value for use with other instructions. If UndefOp is to be used
at module level, it cannot produce an SSA value (use of this SSA value
within FuncOp would need implicit capture). To satisfy needs of the
SPIR-V spec while making it simpler to represent UndefOp in the SPIR-V
dialect, the serialization is updated to create OpUndef instruction
at module scope.

PiperOrigin-RevId: 273355526

mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
mlir/test/Dialect/SPIRV/Serialization/undef.mlir
mlir/test/Dialect/SPIRV/ops.mlir

index 178d358..aa6661a 100644 (file)
@@ -485,7 +485,7 @@ def SPV_StoreOp : SPV_Op<"Store", []> {
 
 // -----
 
-def SPV_UndefOp : SPV_Op<"Undef", []> {
+def SPV_UndefOp : SPV_Op<"undef", []> {
   let summary = "Make an intermediate object whose value is undefined.";
 
   let description = [{
@@ -498,14 +498,14 @@ def SPV_UndefOp : SPV_Op<"Undef", []> {
     ### Custom assembly form
 
     ``` {.ebnf}
-    undef-op ::= `spv.Undef` `:` sprirv-type
+    undef-op ::= `spv.undef` `:` sprirv-type
     ```
 
     For example:
 
     ```
-    %0 = spv.Undef : f32
-    %1 = spv.Undef : !spv.struct<!spv.array<4 x vector<4xi32>>>
+    %0 = spv.undef : f32
+    %1 = spv.undef : !spv.struct<!spv.array<4 x vector<4xi32>>>
     ```
   }];
 
@@ -516,6 +516,9 @@ def SPV_UndefOp : SPV_Op<"Undef", []> {
   );
 
   let verifier = [{ return success(); }];
+
+  let hasOpcode = 0;
+  let autogenSerialization = 0;
 }
 
 // -----
index 3fbf5ee..6c194b2 100644 (file)
@@ -173,6 +173,9 @@ private:
   /// Gets type for a given result <id>.
   Type getType(uint32_t id) { return typeMap.lookup(id); }
 
+  /// Get the type associated with the result <id> of an OpUndef.
+  Type getUndefType(uint32_t id) { return undefMap.lookup(id); }
+
   /// Returns true if the given `type` is for SPIR-V void type.
   bool isVoidType(Type type) const { return type.isa<NoneType>(); }
 
@@ -306,6 +309,10 @@ private:
                                    ArrayRef<uint32_t> operands,
                                    bool deferInstructions = true);
 
+  /// Processes a OpUndef instruction. Adds a spv.Undef operation at the current
+  /// insertion point.
+  LogicalResult processUndef(ArrayRef<uint32_t> operands);
+
   /// Method to dispatch to the specialized deserialization function for an
   /// operation in SPIR-V dialect that is a mirror of an instruction in the
   /// SPIR-V spec. This is auto-generated from ODS. Dispatch is handled for
@@ -396,6 +403,9 @@ private:
   // Result <id> to value mapping.
   DenseMap<uint32_t, Value *> valueMap;
 
+  // Mapping from result <id> to undef value of a type.
+  DenseMap<uint32_t, Type> undefMap;
+
   // Result <id> to name mapping.
   DenseMap<uint32_t, StringRef> nameMap;
 
@@ -1794,6 +1804,9 @@ Value *Deserializer::getValue(uint32_t id) {
         opBuilder.getSymbolRefAttr(constOp.getOperation()));
     return referenceOfOp.reference();
   }
+  if (auto undef = getUndefType(id)) {
+    return opBuilder.create<spirv::UndefOp>(unknownLoc, undef);
+  }
   return valueMap.lookup(id);
 }
 
@@ -1913,12 +1926,26 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
     return processSelectionMerge(operands);
   case spirv::Opcode::OpLoopMerge:
     return processLoopMerge(operands);
+  case spirv::Opcode::OpUndef:
+    return processUndef(operands);
   default:
     break;
   }
   return dispatchToAutogenDeserialization(opcode, operands);
 }
 
+LogicalResult Deserializer::processUndef(ArrayRef<uint32_t> operands) {
+  if (operands.size() != 2) {
+    return emitError(unknownLoc, "OpUndef instruction must have two operands");
+  }
+  auto type = getType(operands[0]);
+  if (!type) {
+    return emitError(unknownLoc, "unknown type <id> with OpUndef instruction");
+  }
+  undefMap[operands[1]] = type;
+  return success();
+}
+
 LogicalResult Deserializer::processExtInst(ArrayRef<uint32_t> operands) {
   if (operands.size() < 4) {
     return emitError(unknownLoc,
index f2314ff..608de98 100644 (file)
@@ -127,6 +127,12 @@ private:
 
   LogicalResult processSpecConstantOp(spirv::SpecConstantOp op);
 
+  /// SPIR-V dialect supports OpUndef using spv.UndefOp that produces a SSA
+  /// value to use with other operations. The SPIR-V spec recommends that
+  /// OpUndef be generated at module level. The serialization generates an
+  /// OpUndef for each type needed at module level.
+  LogicalResult processUndefOp(spirv::UndefOp op);
+
   /// Emit OpName for the given `resultID`.
   LogicalResult processName(uint32_t resultID, StringRef name);
 
@@ -333,6 +339,9 @@ private:
   /// Map from blocks to their <id>s.
   DenseMap<Block *, uint32_t> blockIDMap;
 
+  /// Map from the Type to the <id> that represents undef value of that type.
+  DenseMap<Type, uint32_t> undefValIDMap;
+
   /// Map from results of normal operations to their <id>s.
   DenseMap<Value *, uint32_t> valueIDMap;
 
@@ -449,6 +458,22 @@ LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
   return failure();
 }
 
+LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
+  auto undefType = op.getType();
+  auto &id = undefValIDMap[undefType];
+  if (!id) {
+    id = getNextID();
+    uint32_t typeID = 0;
+    if (failed(processType(op.getLoc(), undefType, typeID)) ||
+        failed(encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpUndef,
+                                     {typeID, id}))) {
+      return failure();
+    }
+  }
+  valueIDMap[op.getResult()] = id;
+  return success();
+}
+
 LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
                                             NamedAttribute attr) {
   auto attrName = attr.first.strref();
@@ -1503,6 +1528,9 @@ LogicalResult Serializer::processOperation(Operation *op) {
   if (auto specConstOp = dyn_cast<spirv::SpecConstantOp>(op)) {
     return processSpecConstantOp(specConstOp);
   }
+  if (auto undefOp = dyn_cast<spirv::UndefOp>(op)) {
+    return processUndefOp(undefOp);
+  }
 
   // Then handle all the ops that directly mirror SPIR-V instructions with
   // auto-generated methods.
index 7c2ccdb..d5cb410 100644 (file)
@@ -1,15 +1,33 @@
-// RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s
+// RUN: mlir-translate -split-input-file -test-spirv-roundtrip %s | FileCheck %s
 
 spv.module "Logical" "GLSL450" {
   func @foo() -> () {
-    // CHECK: {{%.*}} = spv.Undef : f32
-    %0 = spv.Undef : f32
-    // CHECK: {{%.*}} = spv.Undef : vector<4xi32>
-    %1 = spv.Undef : vector<4xi32>
-    // CHECK: {{%.*}} = spv.Undef : !spv.array<4 x !spv.array<4 x i32>>
-    %2 = spv.Undef : !spv.array<4x!spv.array<4xi32>>
-    // CHECK: {{%.*}} = spv.Undef : !spv.ptr<!spv.struct<f32>, StorageBuffer>
-    %3 = spv.Undef : !spv.ptr<!spv.struct<f32>, StorageBuffer>
+    // CHECK: {{%.*}} = spv.undef : f32
+    // CHECK-NEXT: {{%.*}} = spv.undef : f32
+    %0 = spv.undef : f32
+    %1 = spv.undef : f32
+    %2 = spv.FAdd %0, %1 : f32
+    // CHECK: {{%.*}} = spv.undef : vector<4xi32>
+    %3 = spv.undef : vector<4xi32>
+    %4 = spv.CompositeExtract %3[1 : i32] : vector<4xi32>
+    // CHECK: {{%.*}} = spv.undef : !spv.array<4 x !spv.array<4 x i32>>
+    %5 = spv.undef : !spv.array<4x!spv.array<4xi32>>
+    %6 = spv.CompositeExtract %5[1 : i32, 2 : i32] : !spv.array<4x!spv.array<4xi32>>
+    // CHECK: {{%.*}} = spv.undef : !spv.ptr<!spv.struct<f32>, StorageBuffer>
+    %7 = spv.undef : !spv.ptr<!spv.struct<f32>, StorageBuffer>
+    %8 = spv.constant 0 : i32
+    %9 = spv.AccessChain %7[%8] : !spv.ptr<!spv.struct<f32>, StorageBuffer>
+    spv.Return
+  }
+}
+
+// -----
+
+spv.module "Logical" "GLSL450" {
+  // CHECK: func {{@.*}}
+  func @ignore_unused_undef() -> () {
+    // CHECK-NEXT: spv.Return
+    %0 = spv.undef : f32
     spv.Return
   }
 }
\ No newline at end of file
index 63c034f..639e10f 100644 (file)
@@ -968,12 +968,12 @@ spv.module "Logical" "GLSL450" {
 // -----
 
 //===----------------------------------------------------------------------===//
-// spv.Undef
+// spv.undef
 //===----------------------------------------------------------------------===//
 
 func @undef() -> () {
-  %0 = spv.Undef : f32
-  %1 = spv.Undef : vector<4xf32>
+  %0 = spv.undef : f32
+  %1 = spv.undef : vector<4xf32>
   spv.Return
 }
 
@@ -981,7 +981,7 @@ func @undef() -> () {
 
 func @undef() -> () {
   // expected-error @+2{{expected non-function type}}
-  %0 = spv.Undef :
+  %0 = spv.undef :
   spv.Return
 }
 
@@ -989,7 +989,7 @@ func @undef() -> () {
 
 func @undef() -> () {
   // expected-error @+2{{expected ':'}}
-  %0 = spv.Undef
+  %0 = spv.undef
   spv.Return
 }