Add folding rule and dialect materialization hook for spv.constant
authorLei Zhang <antiagainst@google.com>
Tue, 3 Sep 2019 19:09:07 +0000 (12:09 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 3 Sep 2019 19:09:58 +0000 (12:09 -0700)
This will allow us to use MLIR's folding infrastructure to deduplicate
SPIR-V constants.

This CL also changed isValidSPIRVType in SPIRVDialect to a static method.

PiperOrigin-RevId: 266984403

mlir/include/mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h
mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h
mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/canonicalize.mlir [new file with mode: 0644]
mlir/test/Dialect/SPIRV/structure-ops.mlir

index 25a710f..e31c58c 100644 (file)
@@ -33,20 +33,12 @@ namespace mlir {
 class LoadOp;
 class ReturnOp;
 class StoreOp;
-namespace spirv {
-class SPIRVDialect;
-}
 
 /// Type conversion from Standard Types to SPIR-V Types.
 class SPIRVBasicTypeConverter : public TypeConverter {
 public:
-  explicit SPIRVBasicTypeConverter(MLIRContext *context);
-
   /// Converts types to SPIR-V supported types.
   virtual Type convertType(Type t);
-
-protected:
-  spirv::SPIRVDialect *spirvDialect;
 };
 
 /// Converts a function type according to the requirements of a SPIR-V entry
index 494adc1..3086d8d 100644 (file)
@@ -33,14 +33,18 @@ public:
 
   static StringRef getDialectNamespace() { return "spv"; }
 
+  /// Checks if the given `type` is valid in SPIR-V dialect.
+  static bool isValidType(Type type);
+
   /// Parses a type registered to this dialect.
   Type parseType(llvm::StringRef spec, Location loc) const override;
 
   /// Prints a type registered to this dialect.
   void printType(Type type, llvm::raw_ostream &os) const override;
 
-  /// Checks if a type is valid in SPIR-V dialect.
-  bool isValidSPIRVType(Type t) const;
+  /// Provides a hook for materializing a constant to this dialect.
+  Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
+                                 Location loc) override;
 };
 
 } // end namespace spirv
index 2073fa4..8b2eb16 100644 (file)
@@ -112,6 +112,13 @@ def SPV_ConstantOp : SPV_Op<"constant", [NoSideEffect]> {
     SPV_Type:$constant
   );
 
+  let hasFolder = 1;
+
+  let extraClassDeclaration = [{
+    // Returns true if a constant can be built for the given `type`.
+    static bool isBuildableWith(Type type);
+  }];
+
   let hasOpcode = 0;
 }
 
index b25c7a3..d261dc2 100644 (file)
@@ -49,6 +49,7 @@ enum Kind {
   Pointer,
   RuntimeArray,
   Struct,
+  LAST_SPIRV_TYPE = Struct,
 };
 }
 
index 06b2498..6746594 100644 (file)
@@ -142,7 +142,7 @@ void GPUToSPIRVPass::runOnModule() {
   }
 
   /// Dialect conversion to lower the functions with the spirv::ModuleOps.
-  SPIRVBasicTypeConverter basicTypeConverter(context);
+  SPIRVBasicTypeConverter basicTypeConverter;
   SPIRVTypeConverter typeConverter(&basicTypeConverter);
   OwningRewritePatternList patterns;
   patterns.insert<
index e3bcc04..fe2fbd0 100644 (file)
@@ -30,12 +30,9 @@ using namespace mlir;
 // Type Conversion
 //===----------------------------------------------------------------------===//
 
-SPIRVBasicTypeConverter::SPIRVBasicTypeConverter(MLIRContext *context)
-    : spirvDialect(context->getRegisteredDialect<spirv::SPIRVDialect>()) {}
-
 Type SPIRVBasicTypeConverter::convertType(Type t) {
   // Check if the type is SPIR-V supported. If so return the type.
-  if (spirvDialect->isValidSPIRVType(t)) {
+  if (spirv::SPIRVDialect::isValidType(t)) {
     return t;
   }
 
index 85a6a62..b0f3f50 100644 (file)
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Parser.h"
@@ -104,9 +105,10 @@ static bool isValidSPIRVVectorType(VectorType type) {
          type.getNumElements() >= 2 && type.getNumElements() <= 4;
 }
 
-bool SPIRVDialect::isValidSPIRVType(Type type) const {
+bool SPIRVDialect::isValidType(Type type) {
   // Allow SPIR-V dialect types
-  if (&type.getDialect() == this) {
+  if (type.getKind() >= Type::FIRST_SPIRV_TYPE &&
+      type.getKind() <= TypeKind::LAST_SPIRV_TYPE) {
     return true;
   }
   if (isValidSPIRVScalarType(type)) {
@@ -633,3 +635,16 @@ void SPIRVDialect::printType(Type type, llvm::raw_ostream &os) const {
     llvm_unreachable("unhandled SPIR-V type");
   }
 }
+
+//===----------------------------------------------------------------------===//
+// Constant
+//===----------------------------------------------------------------------===//
+
+Operation *SPIRVDialect::materializeConstant(OpBuilder &builder,
+                                             Attribute value, Type type,
+                                             Location loc) {
+  if (!ConstantOp::isBuildableWith(type))
+    return nullptr;
+
+  return builder.create<spirv::ConstantOp>(loc, type, value);
+}
index 66b2f5d..3338c10 100644 (file)
@@ -754,8 +754,10 @@ static LogicalResult verify(spirv::ConstantOp constOp) {
     auto elemType = arrayType.getElementType();
     for (auto element : value.cast<ArrayAttr>().getValue()) {
       if (element.getType() != elemType)
-        return constOp.emitOpError(
-            "has array element that are not of result array element type");
+        return constOp.emitOpError("has array element whose type (")
+               << element.getType()
+               << ") does not match the result element type (" << elemType
+               << ')';
     }
   } break;
   default:
@@ -765,6 +767,25 @@ static LogicalResult verify(spirv::ConstantOp constOp) {
   return success();
 }
 
+OpFoldResult spirv::ConstantOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.empty() && "constant has no operands");
+  return value();
+}
+
+bool spirv::ConstantOp::isBuildableWith(Type type) {
+  // Must be valid SPIR-V type first.
+  if (!SPIRVDialect::isValidType(type))
+    return false;
+
+  if (type.getKind() >= Type::FIRST_SPIRV_TYPE &&
+      type.getKind() <= spirv::TypeKind::LAST_SPIRV_TYPE) {
+    // TODO(antiagainst): support contant struct
+    return type.isa<spirv::ArrayType>();
+  }
+
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // spv.EntryPoint
 //===----------------------------------------------------------------------===//
@@ -1350,8 +1371,7 @@ static LogicalResult verify(spirv::SpecConstantOp constOp) {
   case StandardAttributes::Integer:
   case StandardAttributes::Float: {
     // Make sure bitwidth is allowed.
-    auto *dialect = static_cast<spirv::SPIRVDialect *>(constOp.getDialect());
-    if (!dialect->isValidSPIRVType(value.getType()))
+    if (!spirv::SPIRVDialect::isValidType(value.getType()))
       return constOp.emitOpError("default value bitwidth disallowed");
     return success();
   }
diff --git a/mlir/test/Dialect/SPIRV/canonicalize.mlir b/mlir/test/Dialect/SPIRV/canonicalize.mlir
new file mode 100644 (file)
index 0000000..24e4e99
--- /dev/null
@@ -0,0 +1,35 @@
+// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spv.constant
+//===----------------------------------------------------------------------===//
+
+// TODO(antiagainst): test constants in different blocks
+
+func @deduplicate_scalar_constant() -> (i32, i32) {
+  // CHECK: %[[CST:.*]] = spv.constant 42 : i32
+  %0 = spv.constant 42 : i32
+  %1 = spv.constant 42 : i32
+  // CHECK-NEXT: return %[[CST]], %[[CST]]
+  return %0, %1 : i32, i32
+}
+
+// -----
+
+func @deduplicate_vector_constant() -> (vector<3xi32>, vector<3xi32>) {
+  // CHECK: %[[CST:.*]] = spv.constant dense<[1, 2, 3]> : vector<3xi32>
+  %0 = spv.constant dense<[1, 2, 3]> : vector<3xi32>
+  %1 = spv.constant dense<[1, 2, 3]> : vector<3xi32>
+  // CHECK-NEXT: return %[[CST]], %[[CST]]
+  return %0, %1 : vector<3xi32>, vector<3xi32>
+}
+
+// -----
+
+func @deduplicate_composite_constant() -> (!spv.array<1 x vector<2xi32>>, !spv.array<1 x vector<2xi32>>) {
+  // CHECK: %[[CST:.*]] = spv.constant [dense<5> : vector<2xi32>] : !spv.array<1 x vector<2xi32>>
+  %0 = spv.constant [dense<5> : vector<2xi32>] : !spv.array<1 x vector<2xi32>>
+  %1 = spv.constant [dense<5> : vector<2xi32>] : !spv.array<1 x vector<2xi32>>
+  // CHECK-NEXT: return %[[CST]], %[[CST]]
+  return %0, %1 : !spv.array<1 x vector<2xi32>>, !spv.array<1 x vector<2xi32>>
+}
index b8d75ed..356855d 100644 (file)
@@ -68,7 +68,7 @@ func @unaccepted_std_attr() -> () {
 // -----
 
 func @array_constant() -> () {
-  // expected-error @+1 {{has array element that are not of result array element type}}
+  // expected-error @+1 {{has array element whose type ('vector<2xi32>') does not match the result element type ('vector<2xf32>')}}
   %0 = spv.constant [dense<3.0> : vector<2xf32>, dense<4> : vector<2xi32>] : !spv.array<2xvector<2xf32>>
   return
 }