[mlir][spirv] Add lowering for composite std.constant.
authorDenis Khalikov <khalikov.denis@huawei.com>
Wed, 22 Jan 2020 13:05:27 +0000 (08:05 -0500)
committerLei Zhang <antiagainst@google.com>
Wed, 22 Jan 2020 13:25:00 +0000 (08:25 -0500)
Add lowering for constant operation with ranked tensor type to
spv.constant with spv.array type.

Differential Revision: https://reviews.llvm.org/D73022

mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
mlir/test/Conversion/StandardToSPIRV/std-to-spirv.mlir

index cc535b9..4fcb651 100644 (file)
@@ -25,6 +25,17 @@ using namespace mlir;
 
 namespace {
 
+/// Convert composite constant operation to SPIR-V dialect.
+// TODO(denis0x0D) : move to DRR.
+class ConstantCompositeOpConversion final : public SPIRVOpLowering<ConstantOp> {
+public:
+  using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering;
+
+  PatternMatchResult
+  matchAndRewrite(ConstantOp constCompositeOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 /// Convert constant operation with IndexType return to SPIR-V constant
 /// operation. Since IndexType is not used within SPIR-V dialect, this needs
 /// special handling to make sure the result type and the type of the value
@@ -173,6 +184,39 @@ static spirv::AccessChainOp getElementPtr(OpBuilder &builder,
 }
 
 //===----------------------------------------------------------------------===//
+// ConstantOp with composite type.
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult ConstantCompositeOpConversion::matchAndRewrite(
+    ConstantOp constCompositeOp, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  auto compositeType =
+      constCompositeOp.getResult().getType().dyn_cast<RankedTensorType>();
+  if (!compositeType)
+    return matchFailure();
+
+  auto spirvCompositeType = typeConverter.convertType(compositeType);
+  if (!spirvCompositeType)
+    return matchFailure();
+
+  auto linearizedElements =
+      constCompositeOp.value().dyn_cast<DenseElementsAttr>();
+  if (!linearizedElements)
+    return matchFailure();
+
+  // If composite type has rank greater than one, then perform linearization.
+  if (compositeType.getRank() > 1) {
+    auto linearizedType = RankedTensorType::get(compositeType.getNumElements(),
+                                                compositeType.getElementType());
+    linearizedElements = linearizedElements.reshape(linearizedType);
+  }
+
+  rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
+      constCompositeOp, spirvCompositeType, linearizedElements);
+  return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
 // ConstantOp with index type.
 //===----------------------------------------------------------------------===//
 
@@ -354,7 +398,8 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
                                      OwningRewritePatternList &patterns) {
   // Add patterns that lower operations into SPIR-V dialect.
   populateWithGenerated(context, &patterns);
-  patterns.insert<ConstantIndexOpConversion, CmpFOpConversion, CmpIOpConversion,
+  patterns.insert<ConstantCompositeOpConversion, ConstantIndexOpConversion,
+                  CmpFOpConversion, CmpIOpConversion,
                   IntegerOpConversion<AddIOp, spirv::IAddOp>,
                   IntegerOpConversion<MulIOp, spirv::IMulOp>,
                   IntegerOpConversion<SignedDivIOp, spirv::SDivOp>,
index 2d340b7..22fb39a 100644 (file)
@@ -80,6 +80,19 @@ static Optional<int64_t> getTypeNumBytes(Type t) {
       memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
     }
     return (offset + memrefSize) * elementSize.getValue();
+  } else if (auto tensorType = t.dyn_cast<TensorType>()) {
+    if (!tensorType.hasStaticShape()) {
+      return llvm::None;
+    }
+    auto elementSize = getTypeNumBytes(tensorType.getElementType());
+    if (!elementSize) {
+      return llvm::None;
+    }
+    int64_t size = elementSize.getValue();
+    for (auto shape : tensorType.getShape()) {
+      size *= shape;
+    }
+    return size;
   }
   // TODO: Add size computation for other types.
   return llvm::None;
@@ -131,6 +144,27 @@ static Type convertStdType(Type type) {
     }
   }
 
+  if (auto tensorType = type.dyn_cast<TensorType>()) {
+    // TODO(ravishankarm) : Handle dynamic shapes.
+    if (!tensorType.hasStaticShape()) {
+      return Type();
+    }
+    auto elementType = convertStdType(tensorType.getElementType());
+    if (!elementType) {
+      return Type();
+    }
+    auto elementSize = getTypeNumBytes(elementType);
+    if (!elementSize) {
+      return Type();
+    }
+    auto tensorSize = getTypeNumBytes(tensorType);
+    if (!tensorSize) {
+      return Type();
+    }
+    return spirv::ArrayType::get(elementType,
+                                 tensorSize.getValue() / elementSize.getValue(),
+                                 elementSize.getValue());
+  }
   return Type();
 }
 
index 7d0020a..52f039c 100644 (file)
@@ -220,6 +220,18 @@ func @constant() {
   %3 = constant dense<[2, 3]> : vector<2xi32>
   // CHECK: spv.constant 1 : i32
   %4 = constant 1 : index
+  // CHECK: spv.constant dense<1> : tensor<6xi32> : !spv.array<6 x i32 [4]>
+  %5 = constant dense<1> : tensor<2x3xi32>
+  // CHECK: spv.constant dense<1.000000e+00> : tensor<6xf32> : !spv.array<6 x f32 [4]>
+  %6 = constant dense<1.0> : tensor<2x3xf32>
+  // CHECK: spv.constant dense<{{\[}}1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf32> : !spv.array<6 x f32 [4]>
+  %7 = constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>
+  // CHECK: spv.constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32 [4]>
+  %8 = constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
+  // CHECK: spv.constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32 [4]>
+  %9 =  constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>
+  // CHECK: spv.constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32 [4]>
+  %10 =  constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
   return
 }