Simplify the classes that support SPIR-V conversion.
authorMahesh Ravishankar <ravishankarm@google.com>
Thu, 15 Aug 2019 17:54:22 +0000 (10:54 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 15 Aug 2019 17:54:46 +0000 (10:54 -0700)
Modify the Type converters to have a SPIRVBasicTypeConverter which
only handles conversion from standard types to SPIRV types. Rename
SPIRVEntryFnConverter to SPIRVTypeConverter. This contains the
SPIRVBasicTypeConverter within it.

Remove SPIRVFnLowering class and have separate utility methods to
lower a function as entry function or a non-entry function. The
current setup could end with diamond inheritence that is not very
friendly to use.  For example, you could define the following Op
conversion methods that lower from a dialect "Foo" which resuls in
diamond inheritance.

template<typename OpTy>
class FooDialect : public SPIRVOpLowering<OpTy> {...};
class FooFnLowering : public FooDialect, SPIRVFnLowering {...};

PiperOrigin-RevId: 263597101

mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp

index 21c2842..adfd83b 100644 (file)
@@ -33,12 +33,12 @@ class SPIRVDialect;
 }
 
 /// Type conversion from Standard Types to SPIR-V Types.
-class SPIRVTypeConverter : public TypeConverter {
+class SPIRVBasicTypeConverter : public TypeConverter {
 public:
-  explicit SPIRVTypeConverter(MLIRContext *context);
+  explicit SPIRVBasicTypeConverter(MLIRContext *context);
 
   /// Converts types to SPIR-V supported types.
-  Type convertType(Type t) override;
+  virtual Type convertType(Type t);
 
 protected:
   spirv::SPIRVDialect *spirvDialect;
@@ -47,51 +47,54 @@ protected:
 /// Converts a function type according to the requirements of a SPIR-V entry
 /// function. The arguments need to be converted to spv.Variables of spv.ptr
 /// types so that they could be bound by the runtime.
-class SPIRVEntryFnTypeConverter final : public SPIRVTypeConverter {
+class SPIRVTypeConverter final : public TypeConverter {
 public:
-  using SPIRVTypeConverter::SPIRVTypeConverter;
+  explicit SPIRVTypeConverter(SPIRVBasicTypeConverter *basicTypeConverter)
+      : basicTypeConverter(basicTypeConverter) {}
+
+  /// Convert types to SPIR-V types using the basic type converter.
+  Type convertType(Type t) override {
+    return basicTypeConverter->convertType(t);
+  }
 
   /// Method to convert argument of a function. The `type` is converted to
   /// spv.ptr<type, Uniform>.
   // TODO(ravishankarm) : Support other storage classes.
   LogicalResult convertSignatureArg(unsigned inputNo, Type type,
                                     SignatureConversion &result) override;
+
+  /// Get the basic type converter.
+  SPIRVBasicTypeConverter *getBasicTypeConverter() const {
+    return basicTypeConverter;
+  }
+
+private:
+  SPIRVBasicTypeConverter *basicTypeConverter;
 };
 
 /// Base class to define a conversion pattern to translate Ops into SPIR-V.
 template <typename OpTy> class SPIRVOpLowering : public ConversionPattern {
 public:
-  SPIRVOpLowering(MLIRContext *context, SPIRVTypeConverter &typeConverter,
-                  SPIRVEntryFnTypeConverter &entryFnConverter)
+  SPIRVOpLowering(MLIRContext *context, SPIRVTypeConverter &typeConverter)
       : ConversionPattern(OpTy::getOperationName(), 1, context),
-        typeConverter(typeConverter), entryFnConverter(entryFnConverter) {}
+        typeConverter(typeConverter) {}
 
 protected:
   // Type lowering class.
   SPIRVTypeConverter &typeConverter;
-
-  // Entry function signature converter.
-  SPIRVEntryFnTypeConverter &entryFnConverter;
 };
 
-/// Base Class for legalize a FuncOp within a spv.module. This class can be
-/// extended to implement a ConversionPattern to lower a FuncOp. It provides
-/// hooks to legalize a FuncOp as a simple function, or as an entry function.
-class SPIRVFnLowering : public SPIRVOpLowering<FuncOp> {
-public:
-  using SPIRVOpLowering<FuncOp>::SPIRVOpLowering;
-
-protected:
-  /// Method to legalize the function as a non-entry function.
-  LogicalResult lowerFunction(FuncOp funcOp, ArrayRef<Value *> operands,
-                              ConversionPatternRewriter &rewriter,
-                              FuncOp &newFuncOp) const;
-
-  /// Method to legalize the function as an entry function.
-  LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
-                                     ConversionPatternRewriter &rewriter,
-                                     FuncOp &newFuncOp) const;
-};
+/// Method to legalize a function as a non-entry function.
+LogicalResult lowerFunction(FuncOp funcOp, ArrayRef<Value *> operands,
+                            SPIRVTypeConverter *typeConverter,
+                            ConversionPatternRewriter &rewriter,
+                            FuncOp &newFuncOp);
+
+/// Method to legalize a function as an entry function.
+LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
+                                   SPIRVTypeConverter *typeConverter,
+                                   ConversionPatternRewriter &rewriter,
+                                   FuncOp &newFuncOp);
 
 /// Appends to a pattern list additional patterns for translating StandardOps to
 /// SPIR-V ops.
index c36aee5..ff6af83 100644 (file)
@@ -31,9 +31,9 @@ namespace {
 
 /// Pattern to convert a kernel function in GPU dialect (a FuncOp with the
 /// attribute gpu.kernel) within a spv.module.
-class KernelFnConversion final : public SPIRVFnLowering {
+class KernelFnConversion final : public SPIRVOpLowering<FuncOp> {
 public:
-  using SPIRVFnLowering::SPIRVFnLowering;
+  using SPIRVOpLowering<FuncOp>::SPIRVOpLowering;
 
   PatternMatchResult
   matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
@@ -47,12 +47,14 @@ KernelFnConversion::matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
   auto funcOp = cast<FuncOp>(op);
   FuncOp newFuncOp;
   if (!gpu::GPUDialect::isKernel(funcOp)) {
-    return succeeded(lowerFunction(funcOp, operands, rewriter, newFuncOp))
+    return succeeded(lowerFunction(funcOp, operands, &typeConverter, rewriter,
+                                   newFuncOp))
                ? matchSuccess()
                : matchFailure();
   }
 
-  if (failed(lowerAsEntryFunction(funcOp, operands, rewriter, newFuncOp))) {
+  if (failed(lowerAsEntryFunction(funcOp, operands, &typeConverter, rewriter,
+                                  newFuncOp))) {
     return matchFailure();
   }
   newFuncOp.getOperation()->removeAttr(Identifier::get(
@@ -101,16 +103,17 @@ void GPUToSPIRVPass::runOnModule() {
   }
 
   /// Dialect conversion to lower the functions with the spirv::ModuleOps.
-  SPIRVTypeConverter typeConverter(context);
-  SPIRVEntryFnTypeConverter entryFnConverter(context);
+  SPIRVBasicTypeConverter basicTypeConverter(context);
+  SPIRVTypeConverter typeConverter(&basicTypeConverter);
   OwningRewritePatternList patterns;
-  patterns.insert<KernelFnConversion>(context, typeConverter, entryFnConverter);
+  patterns.insert<KernelFnConversion>(context, typeConverter);
   populateStandardToSPIRVPatterns(context, patterns);
 
   ConversionTarget target(*context);
   target.addLegalDialect<spirv::SPIRVDialect>();
-  target.addDynamicallyLegalOp<FuncOp>(
-      [&](FuncOp Op) { return typeConverter.isSignatureLegal(Op.getType()); });
+  target.addDynamicallyLegalOp<FuncOp>([&](FuncOp Op) {
+    return basicTypeConverter.isSignatureLegal(Op.getType());
+  });
 
   if (failed(applyFullConversion(spirvModules, target, patterns,
                                  &typeConverter))) {
index 067f2ae..53a40df 100644 (file)
@@ -30,10 +30,10 @@ using namespace mlir;
 // Type Conversion
 //===----------------------------------------------------------------------===//
 
-SPIRVTypeConverter::SPIRVTypeConverter(MLIRContext *context)
+SPIRVBasicTypeConverter::SPIRVBasicTypeConverter(MLIRContext *context)
     : spirvDialect(context->getRegisteredDialect<spirv::SPIRVDialect>()) {}
 
-Type SPIRVTypeConverter::convertType(Type t) {
+Type SPIRVBasicTypeConverter::convertType(Type t) {
   // Check if the type is SPIR-V supported. If so return the type.
   if (spirvDialect->isValidSPIRVType(t)) {
     return t;
@@ -58,10 +58,10 @@ Type SPIRVTypeConverter::convertType(Type t) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult
-SPIRVEntryFnTypeConverter::convertSignatureArg(unsigned inputNo, Type type,
-                                               SignatureConversion &result) {
+SPIRVTypeConverter::convertSignatureArg(unsigned inputNo, Type type,
+                                        SignatureConversion &result) {
   // Try to convert the given input type.
-  auto convertedType = convertType(type);
+  auto convertedType = basicTypeConverter->convertType(type);
   // TODO(ravishankarm) : Vulkan spec requires these to be a
   // spirv::StructType. This is not a SPIR-V requirement, so just making this a
   // pointer type for now.
@@ -81,12 +81,10 @@ SPIRVEntryFnTypeConverter::convertSignatureArg(unsigned inputNo, Type type,
   return success();
 }
 
-template <typename Converter>
-static LogicalResult
-lowerFunctionImpl(FuncOp funcOp, ArrayRef<Value *> operands,
-                  ConversionPatternRewriter &rewriter, Converter &typeConverter,
-                  TypeConverter::SignatureConversion &signatureConverter,
-                  FuncOp &newFuncOp) {
+static LogicalResult lowerFunctionImpl(
+    FuncOp funcOp, ArrayRef<Value *> operands,
+    ConversionPatternRewriter &rewriter, TypeConverter *typeConverter,
+    TypeConverter::SignatureConversion &signatureConverter, FuncOp &newFuncOp) {
   auto fnType = funcOp.getType();
 
   if (fnType.getNumResults()) {
@@ -96,7 +94,7 @@ lowerFunctionImpl(FuncOp funcOp, ArrayRef<Value *> operands,
 
   for (auto &argType : enumerate(fnType.getInputs())) {
     // Get the type of the argument
-    if (failed(typeConverter.convertSignatureArg(
+    if (failed(typeConverter->convertSignatureArg(
             argType.index(), argType.value(), signatureConverter))) {
       return funcOp.emitError("unable to convert argument type ")
              << argType.value() << " to SPIR-V type";
@@ -116,23 +114,25 @@ lowerFunctionImpl(FuncOp funcOp, ArrayRef<Value *> operands,
   return success();
 }
 
-LogicalResult
-SPIRVFnLowering::lowerFunction(FuncOp funcOp, ArrayRef<Value *> operands,
-                               ConversionPatternRewriter &rewriter,
-                               FuncOp &newFuncOp) const {
+namespace mlir {
+LogicalResult lowerFunction(FuncOp funcOp, ArrayRef<Value *> operands,
+                            SPIRVTypeConverter *typeConverter,
+                            ConversionPatternRewriter &rewriter,
+                            FuncOp &newFuncOp) {
   auto fnType = funcOp.getType();
   TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
-  return lowerFunctionImpl(funcOp, operands, rewriter, typeConverter,
+  return lowerFunctionImpl(funcOp, operands, rewriter,
+                           typeConverter->getBasicTypeConverter(),
                            signatureConverter, newFuncOp);
 }
 
-LogicalResult
-SPIRVFnLowering::lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
-                                      ConversionPatternRewriter &rewriter,
-                                      FuncOp &newFuncOp) const {
+LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
+                                   SPIRVTypeConverter *typeConverter,
+                                   ConversionPatternRewriter &rewriter,
+                                   FuncOp &newFuncOp) {
   auto fnType = funcOp.getType();
   TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
-  if (failed(lowerFunctionImpl(funcOp, operands, rewriter, entryFnConverter,
+  if (failed(lowerFunctionImpl(funcOp, operands, rewriter, typeConverter,
                                signatureConverter, newFuncOp))) {
     return failure();
   }
@@ -167,6 +167,7 @@ SPIRVFnLowering::lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
       builder.getSymbolRefAttr(newFuncOp.getName()), interface);
   return success();
 }
+} // namespace mlir
 
 //===----------------------------------------------------------------------===//
 // Operation conversion