namespace {
+/// Converts a composite constant operation (a std.constant whose result is a
+/// ranked tensor with a dense attribute) to an equivalent spv.constant in the
+/// SPIR-V dialect, linearizing multi-dimensional values along the way.
+// TODO(denis0x0D): move to DRR.
+class ConstantCompositeOpConversion final : public SPIRVOpLowering<ConstantOp> {
+public:
+ using SPIRVOpLowering<ConstantOp>::SPIRVOpLowering;
+
+ PatternMatchResult
+ matchAndRewrite(ConstantOp constCompositeOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
/// Convert constant operation with IndexType return to SPIR-V constant
/// operation. Since IndexType is not used within SPIR-V dialect, this needs
/// special handling to make sure the result type and the type of the value
}
//===----------------------------------------------------------------------===//
+// ConstantOp with composite type.
+//===----------------------------------------------------------------------===//
+
+PatternMatchResult ConstantCompositeOpConversion::matchAndRewrite(
+ ConstantOp constCompositeOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ // Only ranked tensor results are handled here; other composite forms
+ // (e.g. vectors) are covered by separate patterns.
+ auto compositeType =
+ constCompositeOp.getResult().getType().dyn_cast<RankedTensorType>();
+ if (!compositeType)
+ return matchFailure();
+
+ // The type converter maps the tensor type to a SPIR-V composite type
+ // (an array type); bail out if no legal SPIR-V type exists for it.
+ auto spirvCompositeType = typeConverter.convertType(compositeType);
+ if (!spirvCompositeType)
+ return matchFailure();
+
+ // Only dense-valued constants are supported.
+ auto linearizedElements =
+ constCompositeOp.value().dyn_cast<DenseElementsAttr>();
+ if (!linearizedElements)
+ return matchFailure();
+
+ // If composite type has rank greater than one, then perform linearization:
+ // reshape the attribute to a rank-1 tensor holding the same number of
+ // elements, matching the flat SPIR-V array layout.
+ if (compositeType.getRank() > 1) {
+ auto linearizedType = RankedTensorType::get(compositeType.getNumElements(),
+ compositeType.getElementType());
+ linearizedElements = linearizedElements.reshape(linearizedType);
+ }
+
+ rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
+ constCompositeOp, spirvCompositeType, linearizedElements);
+ return matchSuccess();
+}
+
+//===----------------------------------------------------------------------===//
// ConstantOp with index type.
//===----------------------------------------------------------------------===//
OwningRewritePatternList &patterns) {
// Add patterns that lower operations into SPIR-V dialect.
populateWithGenerated(context, &patterns);
- patterns.insert<ConstantIndexOpConversion, CmpFOpConversion, CmpIOpConversion,
+ patterns.insert<ConstantCompositeOpConversion, ConstantIndexOpConversion,
+ CmpFOpConversion, CmpIOpConversion,
IntegerOpConversion<AddIOp, spirv::IAddOp>,
IntegerOpConversion<MulIOp, spirv::IMulOp>,
IntegerOpConversion<SignedDivIOp, spirv::SDivOp>,
memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
}
return (offset + memrefSize) * elementSize.getValue();
+ } else if (auto tensorType = t.dyn_cast<TensorType>()) {
+ if (!tensorType.hasStaticShape()) {
+ return llvm::None;
+ }
+ auto elementSize = getTypeNumBytes(tensorType.getElementType());
+ if (!elementSize) {
+ return llvm::None;
+ }
+ int64_t size = elementSize.getValue();
+ for (auto shape : tensorType.getShape()) {
+ size *= shape;
+ }
+ return size;
}
// TODO: Add size computation for other types.
return llvm::None;
}
}
+ if (auto tensorType = type.dyn_cast<TensorType>()) {
+ // TODO(ravishankarm) : Handle dynamic shapes.
+ if (!tensorType.hasStaticShape()) {
+ return Type();
+ }
+ auto elementType = convertStdType(tensorType.getElementType());
+ if (!elementType) {
+ return Type();
+ }
+ auto elementSize = getTypeNumBytes(elementType);
+ if (!elementSize) {
+ return Type();
+ }
+ auto tensorSize = getTypeNumBytes(tensorType);
+ if (!tensorSize) {
+ return Type();
+ }
+ return spirv::ArrayType::get(elementType,
+ tensorSize.getValue() / elementSize.getValue(),
+ elementSize.getValue());
+ }
return Type();
}
%3 = constant dense<[2, 3]> : vector<2xi32>
// CHECK: spv.constant 1 : i32
%4 = constant 1 : index
+ // CHECK: spv.constant dense<1> : tensor<6xi32> : !spv.array<6 x i32 [4]>
+ %5 = constant dense<1> : tensor<2x3xi32>
+ // CHECK: spv.constant dense<1.000000e+00> : tensor<6xf32> : !spv.array<6 x f32 [4]>
+ %6 = constant dense<1.0> : tensor<2x3xf32>
+ // CHECK: spv.constant dense<{{\[}}1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf32> : !spv.array<6 x f32 [4]>
+ %7 = constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>
+ // CHECK: spv.constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32 [4]>
+ %8 = constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
+ // CHECK: spv.constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32 [4]>
+ %9 = constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>
+ // CHECK: spv.constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spv.array<6 x i32 [4]>
+ %10 = constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
return
}