[OPENMP][C++20]Add support for CXXRewrittenBinaryOperator in ranged for loops.
authorAlexey Bataev <a.bataev@outlook.com>
Mon, 14 Jun 2021 13:03:42 +0000 (06:03 -0700)
committerAlexey Bataev <a.bataev@outlook.com>
Mon, 14 Jun 2021 18:50:27 +0000 (11:50 -0700)
Added support for CXXRewrittenBinaryOperator as a condition in ranged
for loops. This is a new kind of expression, need to extend support for
  C++20 constructs.
It fixes PR49970: range-based for compilation fails for libstdc++ vector
with -std=c++20.

Differential Revision: https://reviews.llvm.org/D104240

clang/lib/Sema/SemaOpenMP.cpp
clang/test/OpenMP/for_ast_print_cxx20.cpp [new file with mode: 0644]

index a67f091..0e39b8e 100644 (file)
@@ -7668,53 +7668,43 @@ bool OpenMPIterationSpaceChecker::checkAndSetCond(Expr *S) {
   Condition = S;
   S = getExprAsWritten(S);
   SourceLocation CondLoc = S->getBeginLoc();
-  if (auto *BO = dyn_cast<BinaryOperator>(S)) {
-    if (BO->isRelationalOp()) {
-      if (getInitLCDecl(BO->getLHS()) == LCDecl)
-        return setUB(BO->getRHS(),
-                     (BO->getOpcode() == BO_LT || BO->getOpcode() == BO_LE),
-                     (BO->getOpcode() == BO_LT || BO->getOpcode() == BO_GT),
-                     BO->getSourceRange(), BO->getOperatorLoc());
-      if (getInitLCDecl(BO->getRHS()) == LCDecl)
-        return setUB(BO->getLHS(),
-                     (BO->getOpcode() == BO_GT || BO->getOpcode() == BO_GE),
-                     (BO->getOpcode() == BO_LT || BO->getOpcode() == BO_GT),
-                     BO->getSourceRange(), BO->getOperatorLoc());
-    } else if (IneqCondIsCanonical && BO->getOpcode() == BO_NE)
-      return setUB(
-          getInitLCDecl(BO->getLHS()) == LCDecl ? BO->getRHS() : BO->getLHS(),
-          /*LessOp=*/llvm::None,
-          /*StrictOp=*/true, BO->getSourceRange(), BO->getOperatorLoc());
+  auto &&CheckAndSetCond = [this, IneqCondIsCanonical](
+                               BinaryOperatorKind Opcode, const Expr *LHS,
+                               const Expr *RHS, SourceRange SR,
+                               SourceLocation OpLoc) -> llvm::Optional<bool> {
+    if (BinaryOperator::isRelationalOp(Opcode)) {
+      if (getInitLCDecl(LHS) == LCDecl)
+        return setUB(const_cast<Expr *>(RHS),
+                     (Opcode == BO_LT || Opcode == BO_LE),
+                     (Opcode == BO_LT || Opcode == BO_GT), SR, OpLoc);
+      if (getInitLCDecl(RHS) == LCDecl)
+        return setUB(const_cast<Expr *>(LHS),
+                     (Opcode == BO_GT || Opcode == BO_GE),
+                     (Opcode == BO_LT || Opcode == BO_GT), SR, OpLoc);
+    } else if (IneqCondIsCanonical && Opcode == BO_NE) {
+      return setUB(const_cast<Expr *>(getInitLCDecl(LHS) == LCDecl ? RHS : LHS),
+                   /*LessOp=*/llvm::None,
+                   /*StrictOp=*/true, SR, OpLoc);
+    }
+    return llvm::None;
+  };
+  llvm::Optional<bool> Res;
+  if (auto *RBO = dyn_cast<CXXRewrittenBinaryOperator>(S)) {
+    CXXRewrittenBinaryOperator::DecomposedForm DF = RBO->getDecomposedForm();
+    Res = CheckAndSetCond(DF.Opcode, DF.LHS, DF.RHS, RBO->getSourceRange(),
+                          RBO->getOperatorLoc());
+  } else if (auto *BO = dyn_cast<BinaryOperator>(S)) {
+    Res = CheckAndSetCond(BO->getOpcode(), BO->getLHS(), BO->getRHS(),
+                          BO->getSourceRange(), BO->getOperatorLoc());
   } else if (auto *CE = dyn_cast<CXXOperatorCallExpr>(S)) {
     if (CE->getNumArgs() == 2) {
-      auto Op = CE->getOperator();
-      switch (Op) {
-      case OO_Greater:
-      case OO_GreaterEqual:
-      case OO_Less:
-      case OO_LessEqual:
-        if (getInitLCDecl(CE->getArg(0)) == LCDecl)
-          return setUB(CE->getArg(1), Op == OO_Less || Op == OO_LessEqual,
-                       Op == OO_Less || Op == OO_Greater, CE->getSourceRange(),
-                       CE->getOperatorLoc());
-        if (getInitLCDecl(CE->getArg(1)) == LCDecl)
-          return setUB(CE->getArg(0), Op == OO_Greater || Op == OO_GreaterEqual,
-                       Op == OO_Less || Op == OO_Greater, CE->getSourceRange(),
-                       CE->getOperatorLoc());
-        break;
-      case OO_ExclaimEqual:
-        if (IneqCondIsCanonical)
-          return setUB(getInitLCDecl(CE->getArg(0)) == LCDecl ? CE->getArg(1)
-                                                              : CE->getArg(0),
-                       /*LessOp=*/llvm::None,
-                       /*StrictOp=*/true, CE->getSourceRange(),
-                       CE->getOperatorLoc());
-        break;
-      default:
-        break;
-      }
+      Res = CheckAndSetCond(
+          BinaryOperator::getOverloadedOpcode(CE->getOperator()), CE->getArg(0),
+          CE->getArg(1), CE->getSourceRange(), CE->getOperatorLoc());
     }
   }
+  if (Res.hasValue())
+    return *Res;
   if (dependent() || SemaRef.CurContext->isDependentContext())
     return false;
   SemaRef.Diag(CondLoc, diag::err_omp_loop_not_canonical_cond)
diff --git a/clang/test/OpenMP/for_ast_print_cxx20.cpp b/clang/test/OpenMP/for_ast_print_cxx20.cpp
new file mode 100644 (file)
index 0000000..800ce12
--- /dev/null
@@ -0,0 +1,40 @@
+// RUN: %clang_cc1 -verify -fopenmp --std=c++20 -ast-print %s -Wsign-conversion | FileCheck %s
+// RUN: %clang_cc1 -fopenmp -x c++ -std=c++20 -emit-pch -o %t %s
+// RUN: %clang_cc1 -fopenmp -std=c++20 -include-pch %t -fsyntax-only -verify %s -ast-print | FileCheck %s
+
+// RUN: %clang_cc1 -verify -fopenmp-simd --std=c++20 -ast-print %s -Wsign-conversion | FileCheck %s
+// RUN: %clang_cc1 -fopenmp-simd -x c++ -std=c++20 -emit-pch -o %t %s
+// RUN: %clang_cc1 -fopenmp-simd -std=c++20 -include-pch %t -fsyntax-only -verify %s -ast-print | FileCheck %s
+// expected-no-diagnostics
+
+#ifndef HEADER
+#define HEADER
+
+template <typename T> class iterator {
+public:
+  T &operator*() const;
+  iterator &operator++();
+};
+template <typename T>
+bool operator==(const iterator<T> &, const iterator<T> &);
+template <typename T>
+unsigned long operator-(const iterator<T> &, const iterator<T> &);
+template <typename T>
+iterator<T> operator+(const iterator<T> &, unsigned long);
+class vector {
+public:
+  vector();
+  iterator<int> begin();
+  iterator<int> end();
+};
+// CHECK: void foo() {
+void foo() {
+// CHECK-NEXT: vector vec;
+  vector vec;
+// CHECK-NEXT: #pragma omp for
+#pragma omp for
+// CHECK-NEXT: for (int i : vec)
+  for (int i : vec)
+    ;
+}
+#endif