Add lowering of constant ops to SPIR-V.
authorMahesh Ravishankar <ravishankarm@google.com>
Thu, 10 Oct 2019 22:51:35 +0000 (15:51 -0700)
committerJacques Pienaar <jpienaar@google.com>
Fri, 11 Oct 2019 00:19:57 +0000 (17:19 -0700)
The lowering is specified as a pattern and is done only if the result
is a SPIR-V scalar type or vector type.
Handling ConstantOp with index return type needs special handling
since SPIR-V dialect does not have index types. Based on the bitwidth
of the attribute value, either i32 or i64 is chosen.
Other constant lowerings are left as a TODO.

PiperOrigin-RevId: 274056805

mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.td
mlir/test/Conversion/StandardToSPIRV/op_conversion.mlir [moved from mlir/test/Dialect/SPIRV/standard_ops_to_spirv.mlir with 74% similarity]

index b104b53..a607ff5 100644 (file)
@@ -31,6 +31,21 @@ using namespace mlir;
 // Type Conversion
 //===----------------------------------------------------------------------===//
 
+static Type convertIndexType(MLIRContext *context) {
+  // Convert to 32-bit integers for now. Might need a way to control this in
+  // future.
+  // TODO(ravishankarm): It is porbably better to make it 64-bit integers. To
+  // this some support is needed in SPIR-V dialect for Conversion
+  // instructions. The Vulkan spec requires the builtins like
+  // GlobalInvocationID, etc. to be 32-bit (unsigned) integers which should be
+  // SExtended to 64-bit for index computations.
+  return IntegerType::get(32, context);
+}
+
+static Type convertIndexType(IndexType t) {
+  return convertIndexType(t.getContext());
+}
+
 static Type basicTypeConversion(Type t) {
   // Check if the type is SPIR-V supported. If so return the type.
   if (spirv::SPIRVDialect::isValidType(t)) {
@@ -38,8 +53,7 @@ static Type basicTypeConversion(Type t) {
   }
 
   if (auto indexType = t.dyn_cast<IndexType>()) {
-    // Return I32 for index types.
-    return IntegerType::get(32, t.getContext());
+    return convertIndexType(indexType);
   }
 
   if (auto memRefType = t.dyn_cast<MemRefType>()) {
@@ -122,9 +136,9 @@ static Value *createAndLoadGlobalVarForEntryFnArg(PatternRewriter &rewriter,
   // Insert the addressOf and load instructions, to get back the converted value
   // type.
   auto addressOf = rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var);
-  auto zero = rewriter.create<spirv::ConstantOp>(funcOp.getLoc(),
-                                                 rewriter.getIntegerType(32),
-                                                 rewriter.getI32IntegerAttr(0));
+  auto indexType = convertIndexType(funcOp.getContext());
+  auto zero = rewriter.create<spirv::ConstantOp>(
+      funcOp.getLoc(), indexType, rewriter.getIntegerAttr(indexType, 0));
   auto accessChain = rewriter.create<spirv::AccessChainOp>(
       funcOp.getLoc(), addressOf.pointer(), zero.constant());
   // If the original argument is a tensor/memref type, the value is not
@@ -269,6 +283,46 @@ LogicalResult finalizeEntryFunction(FuncOp newFuncOp, OpBuilder &builder) {
 
 namespace {
 
+/// Convert constant operation with IndexType return to SPIR-V constant
+/// operation. Since IndexType is not used within SPIR-V dialect, this needs
+/// special handling to make sure the result type and the type of the value
+/// attribute are consistent.
+class ConstantIndexOpConversion final : public ConversionPattern {
+public:
+  ConstantIndexOpConversion(MLIRContext *context)
+      : ConversionPattern(ConstantOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto constIndexOp = cast<ConstantOp>(op);
+    if (!constIndexOp.getResult()->getType().isa<IndexType>()) {
+      return matchFailure();
+    }
+    // The attribute has index type. Get the integer value and create a new
+    // IntegerAttr.
+    auto constAttr = constIndexOp.value().dyn_cast<IntegerAttr>();
+    if (!constAttr) {
+      return matchFailure();
+    }
+
+    // Use the bitwidth set in the value attribute to decide the result type of
+    // the SPIR-V constant operation since SPIR-V does not support index types.
+    auto constVal = constAttr.getValue();
+    auto constValType = constAttr.getType().dyn_cast<IndexType>();
+    if (!constValType) {
+      return matchFailure();
+    }
+    auto spirvConstType = convertIndexType(constValType);
+    auto spirvConstVal =
+        rewriter.getIntegerAttr(spirvConstType, constAttr.getInt());
+    auto spirvConstantOp = rewriter.create<spirv::ConstantOp>(
+        op->getLoc(), spirvConstType, spirvConstVal);
+    rewriter.replaceOp(op, spirvConstantOp.constant(), {});
+    return matchSuccess();
+  }
+};
+
 /// Convert integer binary operations to SPIR-V operations. Cannot use tablegen
 /// for this. If the integer operation is on variables of IndexType, the type of
 /// the return value of the replacement operation differs from that of the
@@ -375,7 +429,8 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
                                      OwningRewritePatternList &patterns) {
   populateWithGenerated(context, &patterns);
   // Add the return op conversion.
-  patterns.insert<IntegerOpConversion<AddIOp, spirv::IAddOp>,
+  patterns.insert<ConstantIndexOpConversion,
+                  IntegerOpConversion<AddIOp, spirv::IAddOp>,
                   IntegerOpConversion<MulIOp, spirv::IMulOp>, LoadOpConversion,
                   ReturnToSPIRVConversion, StoreOpConversion>(context);
 }
index 5edba26..d9b217b 100644 (file)
@@ -31,4 +31,10 @@ class BinaryOpPattern<Op src, Op tgt> :
 def : BinaryOpPattern<AddFOp, SPV_FAddOp>;
 def : BinaryOpPattern<MulFOp, SPV_FMulOp>;
 
+// Constant Op
+// TODO(ravishankarm): Handle lowering other constant types.
+def : Pat<(ConstantOp:$result $valueAttr),
+          (SPV_ConstantOp $valueAttr),
+          [(SPV_ScalarOrVector $result)]>;
+
 #endif // MLIR_CONVERSION_STANDARDTOSPIRV_TD
@@ -44,3 +44,17 @@ func @fmul_tensor(%arg: tensor<4xf32>) -> tensor<4xf32> {
   return %0 : tensor<4xf32>
 }
 
+// CHECK-LABEL: @constval
+func @constval() {
+  // CHECK: spv.constant true
+  %0 = constant true
+  // CHECK: spv.constant 42 : i64
+  %1 = constant 42
+  // CHECK: spv.constant {{[0-9]*\.[0-9]*e?-?[0-9]*}} : f32
+  %2 = constant 0.5 : f32
+  // CHECK: spv.constant dense<[2, 3]> : vector<2xi32>
+  %3 = constant dense<[2, 3]> : vector<2xi32>
+  // CHECK: spv.constant 1 : i32
+  %4 = constant 1 : index
+  return
+}
\ No newline at end of file