Expose a convenience function to add interface attributes to a function.
authorMahesh Ravishankar <ravishankarm@google.com>
Wed, 11 Dec 2019 20:21:13 +0000 (12:21 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 11 Dec 2019 20:21:42 +0000 (12:21 -0800)
PiperOrigin-RevId: 285036647

mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp

index d8d39e0..1619a5e 100644 (file)
@@ -67,12 +67,6 @@ namespace spirv {
 Value *getBuiltinVariableValue(Operation *op, spirv::BuiltIn builtin,
                                OpBuilder &builder);
 
-/// Legalizes a function as an entry function.
-FuncOp lowerAsEntryFunction(FuncOp funcOp, SPIRVTypeConverter &typeConverter,
-                            ConversionPatternRewriter &rewriter,
-                            ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo,
-                            spirv::EntryPointABIAttr entryPointInfo);
-
 /// Attribute name for specifying argument ABI information.
 StringRef getInterfaceVarABIAttrName();
 
@@ -89,6 +83,18 @@ StringRef getEntryPointABIAttrName();
 EntryPointABIAttr getEntryPointABIAttr(ArrayRef<int32_t> localSize,
                                        MLIRContext *context);
 
+/// Legalizes a function as an entry function.
+FuncOp lowerAsEntryFunction(FuncOp funcOp, SPIRVTypeConverter &typeConverter,
+                            ConversionPatternRewriter &rewriter,
+                            spirv::EntryPointABIAttr entryPointInfo,
+                            ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo);
+
+/// Sets the InterfaceVarABIAttr and EntryPointABIAttr for a function and its
+/// arguments
+LogicalResult setABIAttrs(FuncOp funcOp,
+                          spirv::EntryPointABIAttr entryPointInfo,
+                          ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo);
+
 } // namespace spirv
 } // namespace mlir
 
index 74d105e..2b39c0d 100644 (file)
@@ -224,7 +224,7 @@ KernelFnConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
   auto entryPointAttr =
       spirv::getEntryPointABIAttr(workGroupSizeAsInt32, context);
   FuncOp newFuncOp = spirv::lowerAsEntryFunction(
-      funcOp, typeConverter, rewriter, argABI, entryPointAttr);
+      funcOp, typeConverter, rewriter, entryPointAttr, argABI);
   if (!newFuncOp) {
     return matchFailure();
   }
index 241a588..67c036d 100644 (file)
@@ -252,8 +252,8 @@ Value *mlir::spirv::getBuiltinVariableValue(Operation *op,
 FuncOp mlir::spirv::lowerAsEntryFunction(
     FuncOp funcOp, SPIRVTypeConverter &typeConverter,
     ConversionPatternRewriter &rewriter,
-    ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo,
-    spirv::EntryPointABIAttr entryPointInfo) {
+    spirv::EntryPointABIAttr entryPointInfo,
+    ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo) {
   auto fnType = funcOp.getType();
   if (fnType.getNumResults()) {
     funcOp.emitError("SPIR-V lowering only supports entry functions"
@@ -282,11 +282,18 @@ FuncOp mlir::spirv::lowerAsEntryFunction(
   rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
   rewriter.eraseOp(funcOp);
 
+  spirv::setABIAttrs(newFuncOp, entryPointInfo, argABIInfo);
+  return newFuncOp;
+}
+
+LogicalResult
+mlir::spirv::setABIAttrs(FuncOp funcOp, spirv::EntryPointABIAttr entryPointInfo,
+                         ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo) {
   // Set the attributes for argument and the function.
   StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
-  for (auto argIndex : llvm::seq<unsigned>(0, newFuncOp.getNumArguments())) {
-    newFuncOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
+  for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
+    funcOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
   }
-  newFuncOp.setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo);
-  return newFuncOp;
+  funcOp.setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo);
+  return success();
 }