NFC: Separate implementation and definition in ConvertStandardToSPIRV.cpp
authorMahesh Ravishankar <ravishankarm@google.com>
Fri, 6 Dec 2019 23:25:46 +0000 (15:25 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 6 Dec 2019 23:26:17 +0000 (15:26 -0800)
PiperOrigin-RevId: 284274326

mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp

index c2ca4c9..e87bd4e 100644 (file)
 using namespace mlir;
 
 //===----------------------------------------------------------------------===//
-// Utility functions for operation conversion
-//===----------------------------------------------------------------------===//
-
-/// Performs the index computation to get to the element pointed to by
-/// `indices` using the layout map of `baseType`.
-
-// TODO(ravishankarm) : This method assumes that the `origBaseType` is a
-// MemRefType with AffineMap that has static strides. Handle dynamic strides
-spirv::AccessChainOp getElementPtr(OpBuilder &builder,
-                                   SPIRVTypeConverter &typeConverter,
-                                   Location loc, MemRefType origBaseType,
-                                   Value *basePtr, ArrayRef<Value *> indices) {
-  // Get base and offset of the MemRefType and verify they are static.
-  int64_t offset;
-  SmallVector<int64_t, 4> strides;
-  if (failed(getStridesAndOffset(origBaseType, strides, offset)) ||
-      llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) {
-    return nullptr;
-  }
-
-  auto indexType = typeConverter.getIndexType(builder.getContext());
-
-  Value *ptrLoc = nullptr;
-  assert(indices.size() == strides.size());
-  for (auto index : enumerate(indices)) {
-    Value *strideVal = builder.create<spirv::ConstantOp>(
-        loc, indexType, IntegerAttr::get(indexType, strides[index.index()]));
-    Value *update =
-        builder.create<spirv::IMulOp>(loc, strideVal, index.value());
-    ptrLoc =
-        (ptrLoc ? builder.create<spirv::IAddOp>(loc, ptrLoc, update).getResult()
-                : update);
-  }
-  SmallVector<Value *, 2> linearizedIndices;
-  // Add a '0' at the start to index into the struct.
-  linearizedIndices.push_back(builder.create<spirv::ConstantOp>(
-      loc, indexType, IntegerAttr::get(indexType, 0)));
-  linearizedIndices.push_back(ptrLoc);
-  return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
-}
-
-//===----------------------------------------------------------------------===//
 // Operation conversion
 //===----------------------------------------------------------------------===//
 
@@ -87,33 +45,7 @@ public:
 
   PatternMatchResult
   matchAndRewrite(ConstantOp constIndexOp, ArrayRef<Value *> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    if (!constIndexOp.getResult()->getType().isa<IndexType>()) {
-      return matchFailure();
-    }
-    // The attribute has index type which is not directly supported in
-    // SPIR-V. Get the integer value and create a new IntegerAttr.
-    auto constAttr = constIndexOp.value().dyn_cast<IntegerAttr>();
-    if (!constAttr) {
-      return matchFailure();
-    }
-
-    // Use the bitwidth set in the value attribute to decide the result type
-    // of the SPIR-V constant operation since SPIR-V does not support index
-    // types.
-    auto constVal = constAttr.getValue();
-    auto constValType = constAttr.getType().dyn_cast<IndexType>();
-    if (!constValType) {
-      return matchFailure();
-    }
-    auto spirvConstType =
-        typeConverter.convertType(constIndexOp.getResult()->getType());
-    auto spirvConstVal =
-        rewriter.getIntegerAttr(spirvConstType, constAttr.getInt());
-    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constIndexOp, spirvConstType,
-                                                   spirvConstVal);
-    return matchSuccess();
-  }
+                  ConversionPatternRewriter &rewriter) const override;
 };
 
 /// Convert compare operation to SPIR-V dialect.
@@ -123,31 +55,7 @@ public:
 
   PatternMatchResult
   matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value *> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    CmpIOpOperandAdaptor cmpIOpOperands(operands);
-
-    switch (cmpIOp.getPredicate()) {
-#define DISPATCH(cmpPredicate, spirvOp)                                        \
-  case cmpPredicate:                                                           \
-    rewriter.replaceOpWithNewOp<spirvOp>(                                      \
-        cmpIOp, cmpIOp.getResult()->getType(), cmpIOpOperands.lhs(),           \
-        cmpIOpOperands.rhs());                                                 \
-    return matchSuccess();
-
-      DISPATCH(CmpIPredicate::eq, spirv::IEqualOp);
-      DISPATCH(CmpIPredicate::ne, spirv::INotEqualOp);
-      DISPATCH(CmpIPredicate::slt, spirv::SLessThanOp);
-      DISPATCH(CmpIPredicate::sle, spirv::SLessThanEqualOp);
-      DISPATCH(CmpIPredicate::sgt, spirv::SGreaterThanOp);
-      DISPATCH(CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
-
-#undef DISPATCH
-
-    default:
-      break;
-    }
-    return matchFailure();
-  }
+                  ConversionPatternRewriter &rewriter) const override;
 };
 
 /// Convert integer binary operations to SPIR-V operations. Cannot use
@@ -182,33 +90,18 @@ public:
 
   PatternMatchResult
   matchAndRewrite(LoadOp loadOp, ArrayRef<Value *> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    LoadOpOperandAdaptor loadOperands(operands);
-    auto loadPtr = getElementPtr(rewriter, typeConverter, loadOp.getLoc(),
-                                 loadOp.memref()->getType().cast<MemRefType>(),
-                                 loadOperands.memref(), loadOperands.indices());
-    rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr,
-                                               /*memory_access =*/nullptr,
-                                               /*alignment =*/nullptr);
-    return matchSuccess();
-  }
+                  ConversionPatternRewriter &rewriter) const override;
 };
 
 /// Convert return -> spv.Return.
 // TODO(ravishankarm) : This should be moved into DRR.
-class ReturnToSPIRVConversion final : public SPIRVOpLowering<ReturnOp> {
+class ReturnOpConversion final : public SPIRVOpLowering<ReturnOp> {
 public:
   using SPIRVOpLowering<ReturnOp>::SPIRVOpLowering;
 
   PatternMatchResult
   matchAndRewrite(ReturnOp returnOp, ArrayRef<Value *> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    if (returnOp.getNumOperands()) {
-      return matchFailure();
-    }
-    rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
-    return matchSuccess();
-  }
+                  ConversionPatternRewriter &rewriter) const override;
 };
 
 /// Convert select -> spv.Select
@@ -218,13 +111,7 @@ public:
   using SPIRVOpLowering<SelectOp>::SPIRVOpLowering;
   PatternMatchResult
   matchAndRewrite(SelectOp op, ArrayRef<Value *> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    SelectOpOperandAdaptor selectOperands(operands);
-    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, selectOperands.condition(),
-                                                 selectOperands.true_value(),
-                                                 selectOperands.false_value());
-    return matchSuccess();
-  }
+                  ConversionPatternRewriter &rewriter) const override;
 };
 
 /// Convert store -> spv.StoreOp. The operands of the replaced operation are
@@ -237,22 +124,184 @@ public:
 
   PatternMatchResult
   matchAndRewrite(StoreOp storeOp, ArrayRef<Value *> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    StoreOpOperandAdaptor storeOperands(operands);
-    auto storePtr =
-        getElementPtr(rewriter, typeConverter, storeOp.getLoc(),
-                      storeOp.memref()->getType().cast<MemRefType>(),
-                      storeOperands.memref(), storeOperands.indices());
-    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
-                                                storeOperands.value(),
-                                                /*memory_access =*/nullptr,
-                                                /*alignment =*/nullptr);
-    return matchSuccess();
-  }
+                  ConversionPatternRewriter &rewriter) const override;
 };
 
 } // namespace
 
+//===----------------------------------------------------------------------===//
+// Utility functions for operation conversion
+//===----------------------------------------------------------------------===//
+
+/// Performs the index computation to get to the element pointed to by
+/// `indices` using the layout map of `baseType`.
+
+// TODO(ravishankarm) : This method assumes that the `origBaseType` is a
+// MemRefType with AffineMap that has static strides. Handle dynamic strides
+spirv::AccessChainOp getElementPtr(OpBuilder &builder,
+                                   SPIRVTypeConverter &typeConverter,
+                                   Location loc, MemRefType origBaseType,
+                                   Value *basePtr, ArrayRef<Value *> indices) {
+  // Get base and offset of the MemRefType and verify they are static.
+  int64_t offset;
+  SmallVector<int64_t, 4> strides;
+  if (failed(getStridesAndOffset(origBaseType, strides, offset)) ||
+      llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) {
+    return nullptr;
+  }
+
+  auto indexType = typeConverter.getIndexType(builder.getContext());
+
+  Value *ptrLoc = nullptr;
+  assert(indices.size() == strides.size());
+  for (auto index : enumerate(indices)) {
+    Value *strideVal = builder.create<spirv::ConstantOp>(
+        loc, indexType, IntegerAttr::get(indexType, strides[index.index()]));
+    Value *update =
+        builder.create<spirv::IMulOp>(loc, strideVal, index.value());
+    ptrLoc =
+        (ptrLoc ? builder.create<spirv::IAddOp>(loc, ptrLoc, update).getResult()
+                : update);
+  }
+  SmallVector<Value *, 2> linearizedIndices;
+  // Add a '0' at the start to index into the struct.
+  linearizedIndices.push_back(builder.create<spirv::ConstantOp>(
+      loc, indexType, IntegerAttr::get(indexType, 0)));
+  linearizedIndices.push_back(ptrLoc);
+  return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
+}
+
+//===----------------------------------------------------------------------===//
+// ConstantOp with index type.
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult ConstantIndexOpConversion::matchAndRewrite(
+    ConstantOp constIndexOp, ArrayRef<Value *> operands,
+    ConversionPatternRewriter &rewriter) const {
+  if (!constIndexOp.getResult()->getType().isa<IndexType>()) {
+    return matchFailure();
+  }
+  // The attribute has index type which is not directly supported in
+  // SPIR-V. Get the integer value and create a new IntegerAttr.
+  auto constAttr = constIndexOp.value().dyn_cast<IntegerAttr>();
+  if (!constAttr) {
+    return matchFailure();
+  }
+
+  // Use the bitwidth set in the value attribute to decide the result type
+  // of the SPIR-V constant operation since SPIR-V does not support index
+  // types.
+  auto constVal = constAttr.getValue();
+  auto constValType = constAttr.getType().dyn_cast<IndexType>();
+  if (!constValType) {
+    return matchFailure();
+  }
+  auto spirvConstType =
+      typeConverter.convertType(constIndexOp.getResult()->getType());
+  auto spirvConstVal =
+      rewriter.getIntegerAttr(spirvConstType, constAttr.getInt());
+  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constIndexOp, spirvConstType,
+                                                 spirvConstVal);
+  return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
+// CmpIOp
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult
+CmpIOpConversion::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value *> operands,
+                                  ConversionPatternRewriter &rewriter) const {
+  CmpIOpOperandAdaptor cmpIOpOperands(operands);
+
+  switch (cmpIOp.getPredicate()) {
+#define DISPATCH(cmpPredicate, spirvOp)                                        \
+  case cmpPredicate:                                                           \
+    rewriter.replaceOpWithNewOp<spirvOp>(                                      \
+        cmpIOp, cmpIOp.getResult()->getType(), cmpIOpOperands.lhs(),           \
+        cmpIOpOperands.rhs());                                                 \
+    return matchSuccess();
+
+    DISPATCH(CmpIPredicate::eq, spirv::IEqualOp);
+    DISPATCH(CmpIPredicate::ne, spirv::INotEqualOp);
+    DISPATCH(CmpIPredicate::slt, spirv::SLessThanOp);
+    DISPATCH(CmpIPredicate::sle, spirv::SLessThanEqualOp);
+    DISPATCH(CmpIPredicate::sgt, spirv::SGreaterThanOp);
+    DISPATCH(CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
+
+#undef DISPATCH
+
+  default:
+    break;
+  }
+  return matchFailure();
+}
+
+//===----------------------------------------------------------------------===//
+// LoadOp
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult
+LoadOpConversion::matchAndRewrite(LoadOp loadOp, ArrayRef<Value *> operands,
+                                  ConversionPatternRewriter &rewriter) const {
+  LoadOpOperandAdaptor loadOperands(operands);
+  auto loadPtr = getElementPtr(rewriter, typeConverter, loadOp.getLoc(),
+                               loadOp.memref()->getType().cast<MemRefType>(),
+                               loadOperands.memref(), loadOperands.indices());
+  rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr,
+                                             /*memory_access =*/nullptr,
+                                             /*alignment =*/nullptr);
+  return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
+// ReturnOp
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult
+ReturnOpConversion::matchAndRewrite(ReturnOp returnOp,
+                                    ArrayRef<Value *> operands,
+                                    ConversionPatternRewriter &rewriter) const {
+  if (returnOp.getNumOperands()) {
+    return matchFailure();
+  }
+  rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
+  return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
+// SelectOp
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult
+SelectOpConversion::matchAndRewrite(SelectOp op, ArrayRef<Value *> operands,
+                                    ConversionPatternRewriter &rewriter) const {
+  SelectOpOperandAdaptor selectOperands(operands);
+  rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, selectOperands.condition(),
+                                               selectOperands.true_value(),
+                                               selectOperands.false_value());
+  return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
+// StoreOp
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult
+StoreOpConversion::matchAndRewrite(StoreOp storeOp, ArrayRef<Value *> operands,
+                                   ConversionPatternRewriter &rewriter) const {
+  StoreOpOperandAdaptor storeOperands(operands);
+  auto storePtr =
+      getElementPtr(rewriter, typeConverter, storeOp.getLoc(),
+                    storeOp.memref()->getType().cast<MemRefType>(),
+                    storeOperands.memref(), storeOperands.indices());
+  rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
+                                              storeOperands.value(),
+                                              /*memory_access =*/nullptr,
+                                              /*alignment =*/nullptr);
+  return matchSuccess();
+}
+
 namespace {
 /// Import the Standard Ops to SPIR-V Patterns.
 #include "StandardToSPIRV.cpp.inc"
@@ -264,14 +313,13 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
                                      OwningRewritePatternList &patterns) {
   // Add patterns that lower operations into SPIR-V dialect.
   populateWithGenerated(context, &patterns);
-  patterns
-      .insert<ConstantIndexOpConversion, CmpIOpConversion,
-              IntegerOpConversion<AddIOp, spirv::IAddOp>,
-              IntegerOpConversion<MulIOp, spirv::IMulOp>,
-              IntegerOpConversion<DivISOp, spirv::SDivOp>,
-              IntegerOpConversion<RemISOp, spirv::SModOp>,
-              IntegerOpConversion<SubIOp, spirv::ISubOp>, LoadOpConversion,
-              ReturnToSPIRVConversion, SelectOpConversion, StoreOpConversion>(
-          context, typeConverter);
+  patterns.insert<ConstantIndexOpConversion, CmpIOpConversion,
+                  IntegerOpConversion<AddIOp, spirv::IAddOp>,
+                  IntegerOpConversion<MulIOp, spirv::IMulOp>,
+                  IntegerOpConversion<DivISOp, spirv::SDivOp>,
+                  IntegerOpConversion<RemISOp, spirv::SModOp>,
+                  IntegerOpConversion<SubIOp, spirv::ISubOp>, LoadOpConversion,
+                  ReturnOpConversion, SelectOpConversion, StoreOpConversion>(
+      context, typeConverter);
 }
 } // namespace mlir