[mlir][spirv] Refactoring to avoid calling the same function twice
authorLei Zhang <antiagainst@google.com>
Wed, 26 Feb 2020 20:35:46 +0000 (15:35 -0500)
committerLei Zhang <antiagainst@google.com>
Wed, 26 Feb 2020 20:36:54 +0000 (15:36 -0500)
mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp

index 2d1a66c..c705dc8 100644 (file)
@@ -24,24 +24,23 @@ using namespace mlir;
 // Common utility functions
 //===----------------------------------------------------------------------===//
 
-/// Returns true if the given `irVal` is a scalar or splat vector constant of
-/// the given `boolVal`.
-static bool isScalarOrSplatBoolAttr(Attribute boolAttr, bool boolVal) {
+/// Returns the boolean value under the hood if the given `boolAttr` is a scalar
+/// or splat vector bool constant.
+static Optional<bool> getScalarOrSplatBoolAttr(Attribute boolAttr) {
   if (!boolAttr)
-    return false;
+    return llvm::None;
 
   auto type = boolAttr.getType();
   if (type.isInteger(1)) {
     auto attr = boolAttr.cast<BoolAttr>();
-    return attr.getValue() == boolVal;
+    return attr.getValue();
   }
   if (auto vecType = type.cast<VectorType>()) {
     if (vecType.getElementType().isInteger(1))
       if (auto attr = boolAttr.dyn_cast<SplatElementsAttr>())
-        return attr.getSplatValue().template cast<BoolAttr>().getValue() ==
-               boolVal;
+        return attr.getSplatValue<bool>();
   }
-  return false;
+  return llvm::None;
 }
 
 // Extracts an element from the given `composite` by following the given
@@ -214,13 +213,15 @@ OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) {
 OpFoldResult spirv::LogicalAndOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 2 && "spv.LogicalAnd should take two operands");
 
-  // x && true = x
-  if (isScalarOrSplatBoolAttr(operands.back(), true))
-    return operand1();
+  if (Optional<bool> rhs = getScalarOrSplatBoolAttr(operands.back())) {
+    // x && true = x
+    if (rhs.getValue())
+      return operand1();
 
-  // x && false = false
-  if (isScalarOrSplatBoolAttr(operands.back(), false))
-    return operands.back();
+    // x && false = false
+    if (!rhs.getValue())
+      return operands.back();
+  }
 
   return Attribute();
 }
@@ -243,13 +244,15 @@ void spirv::LogicalNotOp::getCanonicalizationPatterns(
 OpFoldResult spirv::LogicalOrOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 2 && "spv.LogicalOr should take two operands");
 
-  // x || true = true
-  if (isScalarOrSplatBoolAttr(operands.back(), true))
-    return operands.back();
+  if (auto rhs = getScalarOrSplatBoolAttr(operands.back())) {
+    if (rhs.getValue())
+      // x || true = true
+      return operands.back();
 
-  // x || false = x
-  if (isScalarOrSplatBoolAttr(operands.back(), false))
-    return operand1();
+    // x || false = x
+    if (!rhs.getValue())
+      return operand1();
+  }
 
   return Attribute();
 }