[spirv] Check that operand of `spirv::CompositeExtractOp` is constant while folding.
authorDenis Khalikov <khalikov.denis@huawei.com>
Thu, 28 Nov 2019 21:27:26 +0000 (13:27 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 28 Nov 2019 21:27:56 +0000 (13:27 -0800)
Closes tensorflow/mlir#281

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/281 from denis0x0D:sandbox/composite_ex_fold d02d73658bd1b9eaa515eb4e0aee34bc41d4252b
PiperOrigin-RevId: 282971563

mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/canonicalize.mlir

index e82420022ea1eb501b6f7c9cf2fe68dde8b685a7..e8896fac5261d79af4e09958c910592533f01eb6 100644 (file)
@@ -336,6 +336,9 @@ static void printVariableDecorations(Operation *op, OpAsmPrinter &printer,
 // `indices`. Returns a null Attribute if error happens.
 static Attribute extractCompositeElement(Attribute composite,
                                          ArrayRef<unsigned> indices) {
+  // Check that given composite is a constant.
+  if (!composite)
+    return {};
   // Return composite itself if we reach the end of the index chain.
   if (indices.empty())
     return composite;
index 9df7b09b8e4f20c79cf41ced8248372e78344363..b721e4eaec4d1c1c4c5274708bdf230273956f40 100644 (file)
@@ -126,6 +126,17 @@ func @extract_array_interm() -> (vector<2xi32>) {
 
 // -----
 
+// CHECK-LABEL: extract_from_not_constant
+func @extract_from_not_constant() -> i32 {
+  %0 = spv.Variable : !spv.ptr<vector<3xi32>, Function>
+  %1 = spv.Load "Function" %0 : vector<3xi32>
+  // CHECK: spv.CompositeExtract
+  %2 = spv.CompositeExtract %1[0 : i32] : vector<3xi32>
+  spv.ReturnValue %2 : i32
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spv.constant
 //===----------------------------------------------------------------------===//