From cd556f25deef446042f018d07cb5bc4581f09e82 Mon Sep 17 00:00:00 2001 From: Denis Khalikov Date: Thu, 28 Nov 2019 13:27:26 -0800 Subject: [PATCH] [spirv] Check that operand of `spirv::CompositeExtractOp` is constant while folding. Closes tensorflow/mlir#281 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/281 from denis0x0D:sandbox/composite_ex_fold d02d73658bd1b9eaa515eb4e0aee34bc41d4252b PiperOrigin-RevId: 282971563 --- mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 3 +++ mlir/test/Dialect/SPIRV/canonicalize.mlir | 11 +++++++++++ 2 files changed, 14 insertions(+) diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index e824200..e8896fa 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -336,6 +336,9 @@ static void printVariableDecorations(Operation *op, OpAsmPrinter &printer, // `indices`. Returns a null Attribute if error happens. static Attribute extractCompositeElement(Attribute composite, ArrayRef indices) { + // Check that given composite is a constant. + if (!composite) + return {}; // Return composite itself if we reach the end of the index chain. if (indices.empty()) return composite; diff --git a/mlir/test/Dialect/SPIRV/canonicalize.mlir b/mlir/test/Dialect/SPIRV/canonicalize.mlir index 9df7b09..b721e4e 100644 --- a/mlir/test/Dialect/SPIRV/canonicalize.mlir +++ b/mlir/test/Dialect/SPIRV/canonicalize.mlir @@ -126,6 +126,17 @@ func @extract_array_interm() -> (vector<2xi32>) { // ----- +// CHECK-LABEL: extract_from_not_constant +func @extract_from_not_constant() -> i32 { + %0 = spv.Variable : !spv.ptr, Function> + %1 = spv.Load "Function" %0 : vector<3xi32> + // CHECK: spv.CompositeExtract + %2 = spv.CompositeExtract %1[0 : i32] : vector<3xi32> + spv.ReturnValue %2 : i32 +} + +// ----- + //===----------------------------------------------------------------------===// // spv.constant //===----------------------------------------------------------------------===// -- 2.7.4