Make VariableOp instructions be in the first block in the function.
authorHanhan Wang <hanchung@google.com>
Wed, 13 Nov 2019 02:58:36 +0000 (18:58 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 13 Nov 2019 02:59:15 +0000 (18:59 -0800)
Since VariableOp is serialized during processBlock, we add two more fields,
`functionHeader` and `functionBody`, to collect instructions for a function.
After all the blocks have been processed, we append them to the `functions`.

Also, fix a bug in processGlobalVariableOp. The global variables should be
encoded into `typesGlobalValues`.

PiperOrigin-RevId: 280105366

mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
mlir/test/Dialect/SPIRV/Serialization/constant.mlir
mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp

index 0213420..0ff79d9 100644 (file)
@@ -177,6 +177,8 @@ private:
   /// Processes a SPIR-V function op.
   LogicalResult processFuncOp(FuncOp op);
 
+  LogicalResult processVariableOp(spirv::VariableOp op);
+
   /// Process a SPIR-V GlobalVariableOp
   LogicalResult processGlobalVariableOp(spirv::GlobalVariableOp varOp);
 
@@ -374,6 +376,19 @@ private:
   SmallVector<uint32_t, 0> typesGlobalValues;
   SmallVector<uint32_t, 0> functions;
 
+  /// `functionHeader` contains all the instructions that must be in the first
+  /// block in the function, and `functionBody` contains the rest. After
+  /// processing FuncOp, the encoded instructions of a function are appended to
+  /// `functions`. An example of instructions in `functionHeader` in order:
+  /// OpFunction ...
+  /// OpFunctionParameter ...
+  /// OpFunctionParameter ...
+  /// OpLabel ...
+  /// OpVariable ...
+  /// OpVariable ...
+  SmallVector<uint32_t, 0> functionHeader;
+  SmallVector<uint32_t, 0> functionBody;
+
   /// Map from type used in SPIR-V module to their <id>s.
   DenseMap<Type, uint32_t> typeIDMap;
 
@@ -671,6 +686,7 @@ Serializer::processMemberDecoration(uint32_t structID, uint32_t memberIndex,
 
 LogicalResult Serializer::processFuncOp(FuncOp op) {
   LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n");
+  assert(functionHeader.empty() && functionBody.empty());
 
   uint32_t fnTypeID = 0;
   // Generate type of the function.
@@ -694,7 +710,7 @@ LogicalResult Serializer::processFuncOp(FuncOp op) {
   // TODO : Support other function control options.
   operands.push_back(static_cast<uint32_t>(spirv::FunctionControl::None));
   operands.push_back(fnTypeID);
-  encodeInstructionInto(functions, spirv::Opcode::OpFunction, operands);
+  encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands);
 
   // Add function name.
   if (failed(processName(funcID, op.getName()))) {
@@ -709,7 +725,7 @@ LogicalResult Serializer::processFuncOp(FuncOp op) {
     }
     auto argValueID = getNextID();
     valueIDMap[arg] = argValueID;
-    encodeInstructionInto(functions, spirv::Opcode::OpFunctionParameter,
+    encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
                           {argTypeID, argValueID});
   }
 
@@ -718,9 +734,18 @@ LogicalResult Serializer::processFuncOp(FuncOp op) {
     return op.emitError("external function is unhandled");
   }
 
+  // Some instructions (e.g., OpVariable) in a function must be in the first
+  // block in the function. These instructions will be put in functionHeader.
+  // Thus, we put the label in functionHeader first, and omit it from the first
+  // block.
+  encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel,
+                        {getOrCreateBlockID(&op.front())});
+  processBlock(&op.front(), /*omitLabel=*/true);
   if (failed(visitInPrettyBlockOrder(
-          &op.front(), [&](Block *block) { return processBlock(block); })))
+          &op.front(), [&](Block *block) { return processBlock(block); },
+          /*skipHeader=*/true))) {
     return failure();
+  }
 
   // There might be OpPhi instructions who have value references needing to fix.
   for (auto deferredValue : deferredPhiValues) {
@@ -730,14 +755,63 @@ LogicalResult Serializer::processFuncOp(FuncOp op) {
                             << " to id = " << id << '\n');
     assert(id && "OpPhi references undefined value!");
     for (size_t offset : deferredValue.second)
-      functions[offset] = id;
+      functionBody[offset] = id;
   }
   deferredPhiValues.clear();
 
   LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName()
                           << "' --\n");
   // Insert OpFunctionEnd.
-  return encodeInstructionInto(functions, spirv::Opcode::OpFunctionEnd, {});
+  if (failed(encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionEnd,
+                                   {}))) {
+    return failure();
+  }
+
+  functions.append(functionHeader.begin(), functionHeader.end());
+  functions.append(functionBody.begin(), functionBody.end());
+  functionHeader.clear();
+  functionBody.clear();
+
+  return success();
+}
+
+LogicalResult Serializer::processVariableOp(spirv::VariableOp op) {
+  SmallVector<uint32_t, 4> operands;
+  SmallVector<StringRef, 2> elidedAttrs;
+  uint32_t resultID = 0;
+  uint32_t resultTypeID = 0;
+  if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) {
+    return failure();
+  }
+  operands.push_back(resultTypeID);
+  resultID = getNextID();
+  valueIDMap[op.getResult()] = resultID;
+  operands.push_back(resultID);
+  auto attr = op.getAttr(spirv::attributeName<spirv::StorageClass>());
+  if (attr) {
+    operands.push_back(static_cast<uint32_t>(
+        attr.cast<IntegerAttr>().getValue().getZExtValue()));
+  }
+  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
+  for (auto arg : op.getODSOperands(0)) {
+    auto argID = getValueID(arg);
+    if (!argID) {
+      return emitError(op.getLoc(), "operand 0 has a use before def");
+    }
+    operands.push_back(argID);
+  }
+  encodeInstructionInto(functionHeader, spirv::getOpcode<spirv::VariableOp>(),
+                        operands);
+  for (auto attr : op.getAttrs()) {
+    if (llvm::any_of(elidedAttrs,
+                     [&](StringRef elided) { return attr.first.is(elided); })) {
+      continue;
+    }
+    if (failed(processDecoration(op.getLoc(), resultID, attr))) {
+      return failure();
+    }
+  }
+  return success();
 }
 
 LogicalResult
@@ -789,7 +863,7 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
     elidedAttrs.push_back("initializer");
   }
 
-  if (failed(encodeInstructionInto(functions, spirv::Opcode::OpVariable,
+  if (failed(encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable,
                                    operands))) {
     elidedAttrs.push_back("initializer");
     return failure();
@@ -1360,7 +1434,7 @@ Serializer::processBlock(Block *block, bool omitLabel,
                << "[block] " << block << " (id = " << blockID << ")\n");
 
     // Emit OpLabel for this block.
-    encodeInstructionInto(functions, spirv::Opcode::OpLabel, {blockID});
+    encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
   }
 
   // Emit OpPhi instructions for block arguments, if any.
@@ -1431,7 +1505,7 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
         // The op generating this value hasn't been visited yet so we don't have
         // an <id> assigned yet. Record this to fix up later.
         LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n");
-        deferredPhiValues[value].push_back(functions.size() + 1 +
+        deferredPhiValues[value].push_back(functionBody.size() + 1 +
                                            phiArgs.size());
       } else {
         LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n");
@@ -1441,7 +1515,7 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
       phiArgs.push_back(predBlockId);
     }
 
-    encodeInstructionInto(functions, spirv::Opcode::OpPhi, phiArgs);
+    encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs);
     valueIDMap[arg] = phiID;
   }
 
@@ -1465,7 +1539,7 @@ LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
   auto emitSelectionMerge = [&]() {
     // TODO(antiagainst): properly support loop control here
     encodeInstructionInto(
-        functions, spirv::Opcode::OpSelectionMerge,
+        functionBody, spirv::Opcode::OpSelectionMerge,
         {mergeID, static_cast<uint32_t>(spirv::LoopControl::None)});
   };
   // For structured selection, we cannot have blocks in the selection construct
@@ -1489,7 +1563,7 @@ LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
   // contains a spv._merge op, itself. But we need to have an OpLabel
   // instruction to start a new SPIR-V block for ops following this SelectionOp.
   // The block should use the <id> for the merge block.
-  return encodeInstructionInto(functions, spirv::Opcode::OpLabel, {mergeID});
+  return encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
 }
 
 LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
@@ -1513,7 +1587,7 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
   // preceding and following ops. So we need to emit unconditional branches to
   // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow
   // afterwards.
-  encodeInstructionInto(functions, spirv::Opcode::OpBranch, {headerID});
+  encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID});
 
   // We omit the LoopOp's entry block and start serialization from the loop
   // header block. The entry block should not contain any additional ops other
@@ -1532,7 +1606,7 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
   auto emitLoopMerge = [&]() {
     // TODO(antiagainst): properly support loop control here
     encodeInstructionInto(
-        functions, spirv::Opcode::OpLoopMerge,
+        functionBody, spirv::Opcode::OpLoopMerge,
         {mergeID, continueID, static_cast<uint32_t>(spirv::LoopControl::None)});
   };
   if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge)))
@@ -1554,7 +1628,7 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
   // a spv._merge op, itself. But we need to have an OpLabel instruction to
   // start a new SPIR-V block for ops following this LoopOp. The block should
   // use the <id> for the merge block.
-  return encodeInstructionInto(functions, spirv::Opcode::OpLabel, {mergeID});
+  return encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
 }
 
 LogicalResult Serializer::processBranchConditionalOp(
@@ -1569,12 +1643,12 @@ LogicalResult Serializer::processBranchConditionalOp(
       arguments.push_back(val.cast<IntegerAttr>().getInt());
   }
 
-  return encodeInstructionInto(functions, spirv::Opcode::OpBranchConditional,
+  return encodeInstructionInto(functionBody, spirv::Opcode::OpBranchConditional,
                                arguments);
 }
 
 LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
-  return encodeInstructionInto(functions, spirv::Opcode::OpBranch,
+  return encodeInstructionInto(functionBody, spirv::Opcode::OpBranch,
                                {getOrCreateBlockID(branchOp.getTarget())});
 }
 
@@ -1610,7 +1684,7 @@ LogicalResult Serializer::encodeExtensionInstruction(
   extInstOperands.push_back(setID);
   extInstOperands.push_back(extensionOpcode);
   extInstOperands.append(std::next(operands.begin(), 2), operands.end());
-  return encodeInstructionInto(functions, spirv::Opcode::OpExtInst,
+  return encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst,
                                extInstOperands);
 }
 
@@ -1658,6 +1732,9 @@ LogicalResult Serializer::processOperation(Operation *op) {
   if (auto fnOp = dyn_cast<FuncOp>(op)) {
     return processFuncOp(fnOp);
   }
+  if (auto varOp = dyn_cast<spirv::VariableOp>(op)) {
+    return processVariableOp(varOp);
+  }
   if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) {
     return processGlobalVariableOp(varOp);
   }
@@ -1736,7 +1813,7 @@ Serializer::processOp<spirv::ControlBarrierOp>(spirv::ControlBarrierOp op) {
     operands.push_back(operand);
   }
 
-  return encodeInstructionInto(functions, spirv::Opcode::OpControlBarrier,
+  return encodeInstructionInto(functionBody, spirv::Opcode::OpControlBarrier,
                                operands);
 }
 
@@ -1783,7 +1860,7 @@ Serializer::processOp<spirv::MemoryBarrierOp>(spirv::MemoryBarrierOp op) {
     operands.push_back(operand);
   }
 
-  return encodeInstructionInto(functions, spirv::Opcode::OpMemoryBarrier,
+  return encodeInstructionInto(functionBody, spirv::Opcode::OpMemoryBarrier,
                                operands);
 }
 
@@ -1814,7 +1891,7 @@ Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
     valueIDMap[op.getResult(0)] = funcCallID;
   }
 
-  return encodeInstructionInto(functions, spirv::Opcode::OpFunctionCall,
+  return encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall,
                                operands);
 }
 
index acfa40a..9531209 100644 (file)
@@ -165,4 +165,17 @@ spv.module "Logical" "GLSL450" {
     %1 = spv.IAdd %0, %0 : i32
     spv.ReturnValue %1 : i32
   }
+
+  // CHECK-LABEL: @const_variable
+  func @const_variable(%arg0 : i32, %arg1 : i32) -> () {
+    // CHECK: %[[CONST:.*]] = spv.constant 5 : i32
+    // CHECK: spv.Variable init(%[[CONST]]) : !spv.ptr<i32, Function>
+    // CHECK: spv.IAdd %arg0, %arg1
+    %0 = spv.IAdd %arg0, %arg1 : i32
+    %1 = spv.constant 5 : i32
+    %2 = spv.Variable init(%1) : !spv.ptr<i32, Function>
+    %3 = spv.Load "Function" %2 : i32
+    %4 = spv.IAdd %0, %3 : i32
+    spv.Return
+  }
 }
index 34bfd40..f39295a 100644 (file)
@@ -231,7 +231,7 @@ static void emitSerializationFunction(const Record *attrClass,
                   record->getValueAsInt("extendedInstOpcode"), operands);
   } else {
     os << formatv("  encodeInstructionInto("
-                  "functions, spirv::getOpcode<{0}>(), {1});\n",
+                  "functionBody, spirv::getOpcode<{0}>(), {1});\n",
                   op.getQualCppClassName(), operands);
   }