Add operations needed to support lowering of AffineExpr to SPIR-V.
authorMahesh Ravishankar <ravishankarm@google.com>
Tue, 12 Nov 2019 21:19:33 +0000 (13:19 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 12 Nov 2019 21:20:06 +0000 (13:20 -0800)
Lowering of CmpIOp, DivISOp, RemISOp, SubIOp and SelectOp to SPIR-V
dialect enables the lowering of operations generated by AffineExpr ->
StandardOps conversion into the SPIR-V dialect.

PiperOrigin-RevId: 280039204

mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
mlir/test/Conversion/StandardToSPIRV/op_conversion.mlir

index 56b243c..6bb9dea 100644 (file)
@@ -314,8 +314,9 @@ public:
       return matchFailure();
     }
 
-    // Use the bitwidth set in the value attribute to decide the result type of
-    // the SPIR-V constant operation since SPIR-V does not support index types.
+    // Use the bitwidth set in the value attribute to decide the result type
+    // of the SPIR-V constant operation since SPIR-V does not support index
+    // types.
     auto constVal = constAttr.getValue();
     auto constValType = constAttr.getType().dyn_cast<IndexType>();
     if (!constValType) {
@@ -331,11 +332,47 @@ public:
   }
 };
 
-/// Convert integer binary operations to SPIR-V operations. Cannot use tablegen
-/// for this. If the integer operation is on variables of IndexType, the type of
-/// the return value of the replacement operation differs from that of the
-/// replaced operation. This is not handled in tablegen-based pattern
-/// specification.
+/// Convert compare operation to SPIR-V dialect.
+class CmpIOpConversion final : public ConversionPattern {
+public:
+  CmpIOpConversion(MLIRContext *context)
+      : ConversionPattern(CmpIOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto cmpIOp = cast<CmpIOp>(op);
+    CmpIOpOperandAdaptor cmpIOpOperands(operands);
+
+    switch (cmpIOp.getPredicate()) {
+#define DISPATCH(cmpPredicate, spirvOp)                                        \
+  case cmpPredicate:                                                           \
+    rewriter.replaceOpWithNewOp<spirvOp>(op, op->getResult(0)->getType(),      \
+                                         cmpIOpOperands.lhs(),                 \
+                                         cmpIOpOperands.rhs());                \
+    return matchSuccess();
+
+      DISPATCH(CmpIPredicate::EQ, spirv::IEqualOp);
+      DISPATCH(CmpIPredicate::NE, spirv::INotEqualOp);
+      DISPATCH(CmpIPredicate::SLT, spirv::SLessThanOp);
+      DISPATCH(CmpIPredicate::SLE, spirv::SLessThanEqualOp);
+      DISPATCH(CmpIPredicate::SGT, spirv::SGreaterThanOp);
+      DISPATCH(CmpIPredicate::SGE, spirv::SGreaterThanEqualOp);
+
+#undef DISPATCH
+
+    default:
+      break;
+    }
+    return matchFailure();
+  }
+};
+
+/// Convert integer binary operations to SPIR-V operations. Cannot use
+/// tablegen for this. If the integer operation is on variables of IndexType,
+/// the type of the return value of the replacement operation differs from
+/// that of the replaced operation. This is not handled in tablegen-based
+/// pattern specification.
 template <typename StdOp, typename SPIRVOp>
 class IntegerOpConversion final : public ConversionPattern {
 public:
@@ -396,9 +433,25 @@ public:
   }
 };
 
-/// Convert store -> spv.StoreOp. The operands of the replaced operation are of
-/// IndexType while that of the replacement operation are of type i32. This is
-/// not supported in tablegen based pattern specification.
+/// Convert select -> spv.Select
+class SelectOpConversion : public ConversionPattern {
+public:
+  SelectOpConversion(MLIRContext *context)
+      : ConversionPattern(SelectOp::getOperationName(), 1, context) {}
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    SelectOpOperandAdaptor selectOperands(operands);
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, selectOperands.condition(),
+                                                 selectOperands.true_value(),
+                                                 selectOperands.false_value());
+    return matchSuccess();
+  }
+};
+
+/// Convert store -> spv.StoreOp. The operands of the replaced operation are
+/// of IndexType while that of the replacement operation are of type i32. This
+/// is not supported in tablegen based pattern specification.
 // TODO(ravishankarm) : These could potentially be templated on the operation
 // being converted, since the same logic should work for linalg.store.
 class StoreOpConversion final : public ConversionPattern {
@@ -437,9 +490,14 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
                                      OwningRewritePatternList &patterns) {
   populateWithGenerated(context, &patterns);
   // Add the return op conversion.
-  patterns.insert<ConstantIndexOpConversion,
-                  IntegerOpConversion<AddIOp, spirv::IAddOp>,
-                  IntegerOpConversion<MulIOp, spirv::IMulOp>, LoadOpConversion,
-                  ReturnToSPIRVConversion, StoreOpConversion>(context);
+  patterns
+      .insert<ConstantIndexOpConversion, CmpIOpConversion,
+              IntegerOpConversion<AddIOp, spirv::IAddOp>,
+              IntegerOpConversion<MulIOp, spirv::IMulOp>,
+              IntegerOpConversion<DivISOp, spirv::SDivOp>,
+              IntegerOpConversion<RemISOp, spirv::SModOp>,
+              IntegerOpConversion<SubIOp, spirv::ISubOp>, LoadOpConversion,
+              ReturnToSPIRVConversion, SelectOpConversion, StoreOpConversion>(
+          context);
 }
 } // namespace mlir
index 334920c..d0effdd 100644 (file)
@@ -57,4 +57,47 @@ func @constval() {
   // CHECK: spv.constant 1 : i32
   %4 = constant 1 : index
   return
-}
\ No newline at end of file
+}
+
+// CHECK-LABEL: @cmpiop
+func @cmpiop(%arg0 : i32, %arg1 : i32) {
+  // CHECK: spv.IEqual
+  %0 = cmpi "eq", %arg0, %arg1 : i32
+  // CHECK: spv.INotEqual
+  %1 = cmpi "ne", %arg0, %arg1 : i32
+  // CHECK: spv.SLessThan
+  %2 = cmpi "slt", %arg0, %arg1 : i32
+  // CHECK: spv.SLessThanEqual
+  %3 = cmpi "sle", %arg0, %arg1 : i32
+  // CHECK: spv.SGreaterThan
+  %4 = cmpi "sgt", %arg0, %arg1 : i32
+  // CHECK: spv.SGreaterThanEqual
+  %5 = cmpi "sge", %arg0, %arg1 : i32
+  return
+}
+
+// CHECK-LABEL: @select
+func @selectOp(%arg0 : i32, %arg1 : i32) {
+  %0 = cmpi "sle", %arg0, %arg1 : i32
+  // CHECK: spv.Select
+  %1 = select %0, %arg0, %arg1 : i32
+  return
+}
+
+// CHECK-LABEL: @div_rem
+func @div_rem(%arg0 : i32, %arg1 : i32) {
+  // CHECK: spv.SDiv
+  %0 = divis %arg0, %arg1 : i32
+  // CHECK: spv.SMod
+  %1 = remis %arg0, %arg1 : i32
+  return
+}
+
+// CHECK-LABEL: @add_sub
+func @add_sub(%arg0 : i32, %arg1 : i32) {
+  // CHECK: spv.IAdd
+  %0 = addi %arg0, %arg1 : i32
+  // CHECK: spv.ISub
+  %1 = subi %arg0, %arg1 : i32
+  return
+}