[flang] Add fir.select_type conversion to if-then-else ladder
authorValentin Clement <clementval@gmail.com>
Mon, 21 Nov 2022 09:00:17 +0000 (10:00 +0100)
committerValentin Clement <clementval@gmail.com>
Mon, 21 Nov 2022 09:01:42 +0000 (10:01 +0100)
Convert fir.select_type operation to an if-then-else ladder.
The type guards are sorted before the conversion so it follows the
execution of SELECT TYPE construct as mentioned in 11.1.11.2 point 4
of the Fortran standard.

Depends on D138279

Reviewed By: jeanPerier

Differential Revision: https://reviews.llvm.org/D138280

flang/include/flang/Optimizer/Transforms/Passes.td
flang/lib/Optimizer/Dialect/FIROps.cpp
flang/lib/Optimizer/Dialect/FIRType.cpp
flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp
flang/test/Lower/select-type.f90

index b35aa82..4ed3b82 100644 (file)
@@ -140,9 +140,9 @@ def CharacterConversion : Pass<"character-conversion"> {
 def CFGConversion : Pass<"cfg-conversion", "::mlir::func::FuncOp"> {
   let summary = "Convert FIR structured control flow ops to CFG ops.";
   let description = [{
-    Transform the `fir.do_loop`, `fir.if`, and `fir.iterate_while` ops into
-    plain old test and branch operations. Removing the high-level control
-    structures can enable other optimizations.
+    Transform the `fir.do_loop`, `fir.if`, `fir.iterate_while` and
+    `fir.select_type` ops into plain old test and branch operations. Removing
+    the high-level control structures can enable other optimizations.
 
     This pass is required before code gen to the LLVM IR dialect.
   }];
index 016c691..522da2f 100644 (file)
@@ -2934,6 +2934,16 @@ fir::SelectTypeOp::getSuccessorOperands(llvm::ArrayRef<mlir::Value> operands,
   return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
 }
 
+llvm::Optional<mlir::ValueRange>
+fir::SelectTypeOp::getSuccessorOperands(mlir::ValueRange operands,
+                                        unsigned oper) {
+  auto a =
+      (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(getTargetOffsetAttr());
+  auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(
+      getOperandSegmentSizeAttr());
+  return {getSubOperands(oper, getSubOperands(2, operands, segments), a)};
+}
+
 mlir::ParseResult fir::SelectTypeOp::parse(mlir::OpAsmParser &parser,
                                            mlir::OperationState &result) {
   mlir::OpAsmParser::UnresolvedOperand selector;
@@ -3011,8 +3021,11 @@ mlir::LogicalResult fir::SelectTypeOp::verify() {
   if (auto boxType = getSelector().getType().dyn_cast<fir::BoxType>())
     if (!boxType.getEleTy().isa<mlir::NoneType>())
       return emitOpError("selector must be polymorphic");
-  auto cases =
-      getOperation()->getAttrOfType<mlir::ArrayAttr>(getCasesAttr()).getValue();
+  auto typeGuardAttr = getCases();
+  for (unsigned idx = 0; idx < typeGuardAttr.size(); ++idx)
+    if (typeGuardAttr[idx].isa<mlir::UnitAttr>() &&
+        idx != typeGuardAttr.size() - 1)
+      return emitOpError("default must be the last attribute");
   auto count = getNumDest();
   if (count == 0)
     return emitOpError("must have at least one successor");
@@ -3020,10 +3033,10 @@ mlir::LogicalResult fir::SelectTypeOp::verify() {
     return emitOpError("number of conditions and successors don't match");
   if (targetOffsetSize() != count)
     return emitOpError("incorrect number of successor operand groups");
-  for (decltype(count) i = 0; i != count; ++i) {
-    auto &attr = cases[i];
-    if (!(attr.isa<fir::ExactTypeAttr>() || attr.isa<fir::SubclassAttr>() ||
-          attr.isa<mlir::UnitAttr>()))
+  for (unsigned i = 0; i != count; ++i) {
+    if (!(typeGuardAttr[i].isa<fir::ExactTypeAttr>() ||
+          typeGuardAttr[i].isa<fir::SubclassAttr>() ||
+          typeGuardAttr[i].isa<mlir::UnitAttr>()))
       return emitOpError("invalid type-case alternative");
   }
   return mlir::success();
index 67b4d1a..f0b024e 100644 (file)
@@ -486,7 +486,8 @@ mlir::LogicalResult
 fir::ClassType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
                        mlir::Type eleTy) {
   if (eleTy.isa<fir::RecordType, fir::SequenceType, fir::HeapType,
-                fir::PointerType, mlir::NoneType>())
+                fir::PointerType, mlir::NoneType, mlir::IntegerType,
+                mlir::FloatType>())
     return mlir::success();
   return emitError() << "invalid element type\n";
 }
index 72ab70c..af13fec 100644 (file)
@@ -8,13 +8,21 @@
 
 #include "flang/Optimizer/Dialect/FIRDialect.h"
 #include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIROpsSupport.h"
+#include "flang/Optimizer/Support/FIRContext.h"
+#include "flang/Optimizer/Support/InternalNames.h"
+#include "flang/Optimizer/Support/KindMapping.h"
+#include "flang/Optimizer/Support/TypeCode.h"
 #include "flang/Optimizer/Transforms/Passes.h"
+#include "flang/Runtime/derived-api.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/Support/CommandLine.h"
+#include <mutex>
 
 namespace fir {
 #define GEN_PASS_DEF_CFGCONVERSION
@@ -300,20 +308,273 @@ public:
   }
 };
 
+/// SelectTypeOp converted to an if-then-else chain
+///
+/// This lowers the test conditions to calls into the runtime.
+class CfgSelectTypeConv : public OpConversionPattern<fir::SelectTypeOp> {
+public:
+  using OpConversionPattern<fir::SelectTypeOp>::OpConversionPattern;
+
+  CfgSelectTypeConv(mlir::MLIRContext *ctx, std::mutex *moduleMutex)
+      : mlir::OpConversionPattern<fir::SelectTypeOp>(ctx),
+        moduleMutex(moduleMutex) {}
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::SelectTypeOp selectType, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    auto operands = adaptor.getOperands();
+    auto typeGuards = selectType.getCases();
+    unsigned typeGuardNum = typeGuards.size();
+    auto selector = selectType.getSelector();
+    auto loc = selectType.getLoc();
+    auto mod = selectType.getOperation()->getParentOfType<mlir::ModuleOp>();
+    fir::KindMapping kindMap = fir::getKindMapping(mod);
+
+    // Order type guards so the condition and branches are done to respect the
+    // Execution of SELECT TYPE construct as described in the Fortran 2018
+    // standard 11.1.11.2 point 4.
+    // 1. If a TYPE IS type guard statement matches the selector, the block
+    //    following that statement is executed.
+    // 2. Otherwise, if exactly one CLASS IS type guard statement matches the
+    //    selector, the block following that statement is executed.
+    // 3. Otherwise, if several CLASS IS type guard statements match the
+    //    selector, one of these statements will inevitably specify a type that
+    //    is an extension of all the types specified in the others; the block
+    //    following that statement is executed.
+    // 4. Otherwise, if there is a CLASS DEFAULT type guard statement, the block
+    //    following that statement is executed.
+    // 5. Otherwise, no block is executed.
+
+    llvm::SmallVector<unsigned> orderedTypeGuards;
+    llvm::SmallVector<unsigned> orderedClassIsGuards;
+    unsigned defaultGuard = typeGuardNum - 1;
+
+    // The following loop go through the type guards in the fir.select_type
+    // operation and sort them into two lists.
+    // - All the TYPE IS type guard are added in order to the orderedTypeGuards
+    //   list. This list is used at the end to generate the if-then-else ladder.
+    // - CLASS IS type guard are added in a separate list. If a CLASS IS type
+    //   guard type extends a type already present, the type guard is inserted
+    //   before in the list to respect point 3. above. Otherwise it is just
+    //   added in order at the end.
+    for (unsigned t = 0; t < typeGuardNum; ++t) {
+      if (auto a = typeGuards[t].dyn_cast<fir::ExactTypeAttr>()) {
+        orderedTypeGuards.push_back(t);
+        continue;
+      }
+
+      if (auto a = typeGuards[t].dyn_cast<fir::SubclassAttr>()) {
+        if (auto recTy = a.getType().dyn_cast<fir::RecordType>()) {
+          auto dt = mod.lookupSymbol<fir::DispatchTableOp>(recTy.getName());
+          assert(dt && "dispatch table not found");
+          llvm::SmallSet<llvm::StringRef, 4> ancestors =
+              collectAncestors(dt, mod);
+          if (!ancestors.empty()) {
+            auto it = orderedClassIsGuards.begin();
+            while (it != orderedClassIsGuards.end()) {
+              fir::SubclassAttr sAttr =
+                  typeGuards[*it].dyn_cast<fir::SubclassAttr>();
+              if (auto ty = sAttr.getType().dyn_cast<fir::RecordType>()) {
+                if (ancestors.contains(ty.getName()))
+                  break;
+              }
+              ++it;
+            }
+            if (it != orderedClassIsGuards.end()) {
+              // Parent type is present so place it before.
+              orderedClassIsGuards.insert(it, t);
+              continue;
+            }
+          }
+        }
+        orderedClassIsGuards.push_back(t);
+      }
+    }
+    orderedTypeGuards.append(orderedClassIsGuards);
+    orderedTypeGuards.push_back(defaultGuard);
+    assert(orderedTypeGuards.size() == typeGuardNum &&
+           "ordered type guard size doesn't match number of type guards");
+
+    for (unsigned idx : orderedTypeGuards) {
+      auto *dest = selectType.getSuccessor(idx);
+      llvm::Optional<mlir::ValueRange> destOps =
+          selectType.getSuccessorOperands(operands, idx);
+      if (typeGuards[idx].dyn_cast<mlir::UnitAttr>())
+        rewriter.replaceOpWithNewOp<mlir::cf::BranchOp>(selectType, dest);
+      else if (mlir::failed(genTypeLadderStep(loc, selector, typeGuards[idx],
+                                              dest, destOps, mod, rewriter,
+                                              kindMap)))
+        return mlir::failure();
+    }
+    return mlir::success();
+  }
+
+  llvm::SmallSet<llvm::StringRef, 4>
+  collectAncestors(fir::DispatchTableOp dt, mlir::ModuleOp mod) const {
+    llvm::SmallSet<llvm::StringRef, 4> ancestors;
+    if (!dt.getParent().has_value())
+      return ancestors;
+    while (dt.getParent().has_value()) {
+      ancestors.insert(*dt.getParent());
+      dt = mod.lookupSymbol<fir::DispatchTableOp>(*dt.getParent());
+    }
+    return ancestors;
+  }
+
+  // Generate comparison of type descriptor addresses.
+  mlir::Value genTypeDescCompare(mlir::Location loc, mlir::Value selector,
+                                 mlir::Type ty, mlir::ModuleOp mod,
+                                 mlir::PatternRewriter &rewriter) const {
+    assert(ty.isa<fir::RecordType>() && "expect fir.record type");
+    fir::RecordType recTy = ty.dyn_cast<fir::RecordType>();
+    std::string typeDescName =
+        fir::NameUniquer::getTypeDescriptorName(recTy.getName());
+    auto typeDescGlobal = mod.lookupSymbol<fir::GlobalOp>(typeDescName);
+    if (!typeDescGlobal)
+      return {};
+    auto typeDescAddr = rewriter.create<fir::AddrOfOp>(
+        loc, fir::ReferenceType::get(typeDescGlobal.getType()),
+        typeDescGlobal.getSymbol());
+    auto intPtrTy = rewriter.getIndexType();
+    mlir::Type tdescType =
+        fir::TypeDescType::get(mlir::NoneType::get(rewriter.getContext()));
+    mlir::Value selectorTdescAddr =
+        rewriter.create<fir::BoxTypeDescOp>(loc, tdescType, selector);
+    auto typeDescInt =
+        rewriter.create<fir::ConvertOp>(loc, intPtrTy, typeDescAddr);
+    auto selectorTdescInt =
+        rewriter.create<fir::ConvertOp>(loc, intPtrTy, selectorTdescAddr);
+    return rewriter.create<mlir::arith::CmpIOp>(
+        loc, mlir::arith::CmpIPredicate::eq, typeDescInt, selectorTdescInt);
+  }
+
+  static int getTypeCode(mlir::Type ty, fir::KindMapping &kindMap) {
+    if (auto intTy = ty.dyn_cast<mlir::IntegerType>())
+      return fir::integerBitsToTypeCode(intTy.getWidth());
+    if (auto floatTy = ty.dyn_cast<mlir::FloatType>())
+      return fir::realBitsToTypeCode(floatTy.getWidth());
+    if (auto logicalTy = ty.dyn_cast<fir::LogicalType>())
+      return fir::logicalBitsToTypeCode(
+          kindMap.getLogicalBitsize(logicalTy.getFKind()));
+    if (fir::isa_complex(ty)) {
+      if (auto cmplxTy = ty.dyn_cast<mlir::ComplexType>())
+        return fir::complexBitsToTypeCode(
+            cmplxTy.getElementType().cast<mlir::FloatType>().getWidth());
+      auto cmplxTy = ty.cast<fir::ComplexType>();
+      return fir::complexBitsToTypeCode(
+          kindMap.getRealBitsize(cmplxTy.getFKind()));
+    }
+    return 0; // TODO more types.
+  }
+
+  mlir::LogicalResult
+  genTypeLadderStep(mlir::Location loc, mlir::Value selector,
+                    mlir::Attribute attr, mlir::Block *dest,
+                    llvm::Optional<mlir::ValueRange> destOps,
+                    mlir::ModuleOp mod, mlir::PatternRewriter &rewriter,
+                    fir::KindMapping &kindMap) const {
+    mlir::Value cmp;
+    // TYPE IS type guard comparison are all done inlined.
+    if (auto a = attr.dyn_cast<fir::ExactTypeAttr>()) {
+      if (fir::isa_trivial(a.getType())) {
+        // For type guard statement with Intrinsic type spec the type code of
+        // the descriptor is compared.
+        int code = getTypeCode(a.getType(), kindMap);
+        if (code == 0)
+          return mlir::emitError(loc)
+                 << "type code not done for " << a.getType();
+        mlir::Value typeCode = rewriter.create<mlir::arith::ConstantOp>(
+            loc, rewriter.getI8IntegerAttr(code));
+        mlir::Value selectorTypeCode = rewriter.create<fir::BoxTypeCodeOp>(
+            loc, rewriter.getI8Type(), selector);
+        cmp = rewriter.create<mlir::arith::CmpIOp>(
+            loc, mlir::arith::CmpIPredicate::eq, selectorTypeCode, typeCode);
+      } else {
+        // Flang inline the kind parameter in the type descriptor so we can
+        // directly check if the type descriptor addresses are identical for
+        // the TYPE IS type guard statement.
+        mlir::Value res =
+            genTypeDescCompare(loc, selector, a.getType(), mod, rewriter);
+        if (!res)
+          return mlir::failure();
+        cmp = res;
+      }
+      // CLASS IS type guard statement is done with a runtime call.
+    } else if (auto a = attr.dyn_cast<fir::SubclassAttr>()) {
+      // Retrieve the type descriptor from the type guard statement record type.
+      assert(a.getType().isa<fir::RecordType>() && "expect fir.record type");
+      fir::RecordType recTy = a.getType().dyn_cast<fir::RecordType>();
+      std::string typeDescName =
+          fir::NameUniquer::getTypeDescriptorName(recTy.getName());
+      auto typeDescGlobal = mod.lookupSymbol<fir::GlobalOp>(typeDescName);
+      auto typeDescAddr = rewriter.create<fir::AddrOfOp>(
+          loc, fir::ReferenceType::get(typeDescGlobal.getType()),
+          typeDescGlobal.getSymbol());
+      mlir::Type typeDescTy = ReferenceType::get(rewriter.getNoneType());
+      mlir::Value typeDesc =
+          rewriter.create<ConvertOp>(loc, typeDescTy, typeDescAddr);
+
+      // Prepare the selector descriptor for the runtime call.
+      mlir::Type descNoneTy = fir::BoxType::get(rewriter.getNoneType());
+      mlir::Value descSelector =
+          rewriter.create<ConvertOp>(loc, descNoneTy, selector);
+
+      // Generate runtime call.
+      llvm::StringRef fctName = RTNAME_STRING(ClassIs);
+      mlir::func::FuncOp callee;
+      {
+        // Since conversion is done in parallel for each fir.select_type
+        // operation, the runtime function insertion must be threadsafe.
+        std::lock_guard<std::mutex> lock(*moduleMutex);
+        callee =
+            fir::createFuncOp(rewriter.getUnknownLoc(), mod, fctName,
+                              rewriter.getFunctionType({descNoneTy, typeDescTy},
+                                                       rewriter.getI1Type()));
+      }
+      cmp = rewriter
+                .create<fir::CallOp>(loc, callee,
+                                     mlir::ValueRange{descSelector, typeDesc})
+                .getResult(0);
+    }
+
+    auto *thisBlock = rewriter.getInsertionBlock();
+    auto *newBlock =
+        rewriter.createBlock(dest->getParent(), mlir::Region::iterator(dest));
+    rewriter.setInsertionPointToEnd(thisBlock);
+    if (destOps.has_value())
+      rewriter.create<mlir::cf::CondBranchOp>(loc, cmp, dest, destOps.value(),
+                                              newBlock, llvm::None);
+    else
+      rewriter.create<mlir::cf::CondBranchOp>(loc, cmp, dest, newBlock);
+    rewriter.setInsertionPointToEnd(newBlock);
+    return mlir::success();
+  }
+
+private:
+  // Mutex used to guard insertion of mlir::func::FuncOp in the module.
+  std::mutex *moduleMutex;
+};
+
 /// Convert FIR structured control flow ops to CFG ops.
 class CfgConversion : public fir::impl::CFGConversionBase<CfgConversion> {
 public:
+  mlir::LogicalResult initialize(mlir::MLIRContext *ctx) override {
+    moduleMutex = new std::mutex();
+    return mlir::success();
+  }
+
   void runOnOperation() override {
     auto *context = &getContext();
     mlir::RewritePatternSet patterns(context);
     patterns.insert<CfgLoopConv, CfgIfConv, CfgIterWhileConv>(
         context, forceLoopToExecuteOnce);
+    patterns.insert<CfgSelectTypeConv>(context, moduleMutex);
     mlir::ConversionTarget target(*context);
     target.addLegalDialect<mlir::AffineDialect, mlir::cf::ControlFlowDialect,
                            FIROpsDialect, mlir::func::FuncDialect>();
 
     // apply the patterns
-    target.addIllegalOp<ResultOp, DoLoopOp, IfOp, IterWhileOp>();
+    target.addIllegalOp<ResultOp, DoLoopOp, IfOp, IterWhileOp, SelectTypeOp>();
     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
     if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
                                                   std::move(patterns)))) {
@@ -322,6 +583,9 @@ public:
       signalPassFailure();
     }
   }
+
+private:
+  std::mutex *moduleMutex;
 };
 } // namespace
 
index c846ea1..f14dd58 100644 (file)
@@ -1,5 +1,5 @@
 ! RUN: bbc -polymorphic-type -emit-fir %s -o - | FileCheck %s
-
+! RUN: bbc -polymorphic-type -emit-fir %s -o - | fir-opt --cfg-conversion | FileCheck --check-prefix=CFG %s
 module select_type_lower_test
   type p1
     integer :: a
@@ -15,6 +15,10 @@ module select_type_lower_test
     real(k) :: r
   end type
 
+  type, extends(p2) :: p4
+    integer :: d
+  end type
+
 contains
 
   function get_class()
@@ -49,6 +53,39 @@ contains
 ! CHECK: %{{.*}} = fir.coordinate_of %[[P2]], %[[FIELD]] : (!fir.class<!fir.type<_QMselect_type_lower_testTp2{a:i32,b:i32,c:i32}>>, !fir.field) -> !fir.ref<i32>
 ! CHECK: ^[[DEFAULT_BLOCK]]
 
+! CFG-LABEL: func.func @_QMselect_type_lower_testPselect_type1(
+! CFG-SAME: %[[ARG0:.*]]: !fir.class<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>> {fir.bindc_name = "a"}) {
+! CFG:      %[[TDESC_P1_ADDR:.*]] = fir.address_of(@_QMselect_type_lower_testE.dt.p1) : !fir.ref<!fir.type<{{.*}}>>
+! CFG:      %[[BOX_TDESC:.*]] = fir.box_tdesc %[[ARG0]] : (!fir.class<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>) -> !fir.tdesc<none>
+! CFG:      %[[TDESC_P1_CONV:.*]] = fir.convert %[[TDESC_P1_ADDR]] : (!fir.ref<!fir.type<{{.*}}>>) -> index
+! CFG:      %[[BOX_TDESC_CONV:.*]] = fir.convert %[[BOX_TDESC]] : (!fir.tdesc<none>) -> index
+! CFG:      %[[TDESC_CMP:.*]] = arith.cmpi eq, %[[TDESC_P1_CONV]], %[[BOX_TDESC_CONV]] : index
+! CFG:      cf.cond_br %[[TDESC_CMP]], ^[[TYPE_IS_P1_BLK:.*]], ^[[NOT_TYPE_IS_P1_BLK:.*]]
+! CFG:    ^[[NOT_TYPE_IS_P1_BLK]]:
+! CFG:      %[[TDESC_P2_ADDR:.*]] = fir.address_of(@_QMselect_type_lower_testE.dt.p2) : !fir.ref<!fir.type<{{.*}}>>
+! CFG:      %[[TDESC_P2_CONV:.*]] = fir.convert %[[TDESC_P2_ADDR]] : (!fir.ref<!fir.type<{{.*}}>>) -> !fir.ref<none>
+! CFG:      %[[BOX_NONE:.*]] = fir.convert %[[ARG0]] : (!fir.class<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>) -> !fir.box<none>
+! CFG:      %[[CLASS_IS:.*]] = fir.call @_FortranAClassIs(%[[BOX_NONE]], %[[TDESC_P2_CONV]]) : (!fir.box<none>, !fir.ref<none>) -> i1
+! CFG:      cf.cond_br %[[CLASS_IS]], ^bb[[CLASS_IS_P2_BLK:.*]], ^[[NOT_CLASS_IS_P2_BLK:.*]]
+! CFG:    ^[[TYPE_IS_P1_BLK]]:
+! CFG:      cf.br ^bb[[EXIT_SELECT_BLK:[0-9]]]
+! CFG:    ^bb[[NOT_CLASS_IS_P1_BLK:[0-9]]]:
+! CFG:      cf.br ^bb[[DEFAULT_BLK:[0-9]]]
+! CFG:    ^bb[[CLASS_IS_P1_BLK:[0-9]]]:
+! CFG:      cf.br ^[[END_SELECT_BLK:.*]]
+! CFG:    ^[[NOT_CLASS_IS_P2_BLK]]:
+! CFG:      %[[TDESC_P1_ADDR:.*]] = fir.address_of(@_QMselect_type_lower_testE.dt.p1) : !fir.ref<!fir.type<{{.*}}>>
+! CFG:      %[[TDESC_P1_CONV:.*]] = fir.convert %[[TDESC_P1_ADDR]] : (!fir.ref<!fir.type<{{.*}}>>) -> !fir.ref<none>
+! CFG:      %[[BOX_NONE:.*]] = fir.convert %[[ARG0]] : (!fir.class<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>) -> !fir.box<none>
+! CFG:      %[[CLASS_IS:.*]] = fir.call @_FortranAClassIs(%[[BOX_NONE]], %[[TDESC_P1_CONV]]) : (!fir.box<none>, !fir.ref<none>) -> i1
+! CFG:      cf.cond_br %[[CLASS_IS]], ^bb[[CLASS_IS_P1_BLK]], ^bb[[NOT_CLASS_IS_P1_BLK]]
+! CFG:    ^bb[[CLASS_IS_P2_BLK]]:
+! CFG:      cf.br ^[[END_SELECT_BLK]]
+! CFG:    ^bb[[DEFAULT_BLK]]:
+! CFG:      cf.br ^[[END_SELECT_BLK]]
+! CFG:    ^[[END_SELECT_BLK]]:
+! CFG:      return
+
   subroutine select_type2()
     select type (a => get_class())
     type is (p1)
@@ -71,6 +108,34 @@ contains
 ! CHECK: ^[[CLASS_IS_BLK]]
 ! CHECK: ^[[DEFAULT_BLK]]
 
+! CFG-LABEL: func.func @_QMselect_type_lower_testPselect_type2() {
+! CFG:     %[[CLASS_ALLOCA:.*]] = fir.alloca !fir.class<!fir.ptr<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>> {bindc_name = ".result"}
+! CFG:     %[[GET_CLASS:.*]] = fir.call @_QMselect_type_lower_testPget_class() {{.*}} : () -> !fir.class<!fir.ptr<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>>
+! CFG:     fir.save_result %[[GET_CLASS]] to %[[CLASS_ALLOCA]] : !fir.class<!fir.ptr<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>>, !fir.ref<!fir.class<!fir.ptr<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>>>
+! CFG:     %[[LOAD_CLASS:.*]] = fir.load %[[CLASS_ALLOCA]] : !fir.ref<!fir.class<!fir.ptr<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>>>
+! CFG:     %[[TDESC_P1_ADDR:.*]] = fir.address_of(@_QMselect_type_lower_testE.dt.p1) : !fir.ref<!fir.type<{{.*}}>>
+! CFG:     %[[CLASS_TDESC:.*]] = fir.box_tdesc %[[LOAD_CLASS]] : (!fir.class<!fir.ptr<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>>) -> !fir.tdesc<none>
+! CFG:     %[[TDESC_P1_CONV:.*]] = fir.convert %[[TDESC_P1_ADDR]] : (!fir.ref<!fir.type<{{.*}}>>) -> index
+! CFG:     %[[BOX_TDESC_CONV:.*]] = fir.convert %[[CLASS_TDESC]] : (!fir.tdesc<none>) -> index
+! CFG:     %[[TDESC_CMP:.*]] = arith.cmpi eq, %[[TDESC_P1_CONV]], %[[BOX_TDESC_CONV]] : index
+! CFG:     cf.cond_br %[[TDESC_CMP]], ^[[TYPE_IS_P1_BLK:.*]], ^[[NOT_TYPE_IS_P1_BLK:.*]]
+! CFG:   ^[[NOT_TYPE_IS_P1_BLK]]:
+! CFG:     %[[TDESC_P1_ADDR:.*]] = fir.address_of(@_QMselect_type_lower_testE.dt.p1) : !fir.ref<!fir.type<{{.*}}>>
+! CFG:     %[[TDESC_P1_CONV:.*]] = fir.convert %[[TDESC_P1_ADDR]] : (!fir.ref<!fir.type<{{.*}}>>) -> !fir.ref<none>
+! CFG:     %[[BOX_NONE:.*]] = fir.convert %[[LOAD_CLASS]] : (!fir.class<!fir.ptr<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>>) -> !fir.box<none>
+! CFG:     %[[CLASS_IS:.*]] = fir.call @_FortranAClassIs(%[[BOX_NONE]], %[[TDESC_P1_CONV]]) : (!fir.box<none>, !fir.ref<none>) -> i1
+! CFG:     cf.cond_br %[[CLASS_IS]], ^[[CLASS_IS_BLK:.*]], ^[[NOT_CLASS_IS_BLK:.*]]
+! CFG:   ^[[TYPE_IS_P1_BLK]]:
+! CFG:     cf.br ^bb[[EXIT_SELECT_BLK:[0-9]]]
+! CFG:   ^[[NOT_CLASS_IS_BLK]]:
+! CFG:     cf.br ^bb[[DEFAULT_BLK:[0-9]]]
+! CFG:   ^[[CLASS_IS_BLK]]:
+! CFG:     cf.br ^bb[[END_SELECT_BLK:[0-9]]]
+! CFG:   ^bb[[DEFAULT_BLK]]:
+! CFG:     cf.br ^bb[[END_SELECT_BLK:[0-9]]]
+! CFG:   ^bb[[END_SELECT_BLK:[0-9]]]:
+! CFG:     return
+
   subroutine select_type3(a)
     class(p1), pointer, intent(in) :: a(:)
 
@@ -96,6 +161,32 @@ contains
 ! CHECK: ^[[CLASS_IS_BLK]]
 ! CHECK: ^[[DEFAULT_BLK]]
 
+! CFG-LABEL: func.func @_QMselect_type_lower_testPselect_type3(
+! CFG-SAME: %[[ARG0:.*]]: !fir.ref<!fir.class<!fir.ptr<!fir.array<?x!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>>>> {fir.bindc_name = "a"}) {
+! CFG:      %[[SELECTOR:.*]] = fir.embox %{{.*}} tdesc %{{.*}} : (!fir.ref<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>, !fir.tdesc<none>) -> !fir.class<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>
+! CFG:      %[[TDESC_P1_ADDR:.*]] = fir.address_of(@_QMselect_type_lower_testE.dt.p1) : !fir.ref<!fir.type<{{.*}}>>
+! CFG:      %[[SELECTOR_TDESC:.*]] = fir.box_tdesc %[[SELECTOR]] : (!fir.class<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>) -> !fir.tdesc<none>
+! CFG:      %[[TDESC_P1_CONV:.*]] = fir.convert %[[TDESC_P1_ADDR]] : (!fir.ref<!fir.type<{{.*}}>>) -> index
+! CFG:      %[[TDESC_CONV:.*]] = fir.convert %[[SELECTOR_TDESC]] : (!fir.tdesc<none>) -> index
+! CFG:      %[[TDESC_CMP:.*]] = arith.cmpi eq, %[[TDESC_P1_CONV]], %[[TDESC_CONV]] : index
+! CFG:      cf.cond_br %[[TDESC_CMP]], ^[[TYPE_IS_P1_BLK:.*]], ^[[NOT_TYPE_IS_P1_BLK:.*]]
+! CFG:    ^[[NOT_TYPE_IS_P1_BLK]]:
+! CFG:      %[[TDESC_P1_ADDR:.*]] = fir.address_of(@_QMselect_type_lower_testE.dt.p1) : !fir.ref<!fir.type<{{.*}}>>
+! CFG:      %[[TDESC_P1_CONV:.*]] = fir.convert %[[TDESC_P1_ADDR]] : (!fir.ref<!fir.type<{{.*}}>>) -> !fir.ref<none>
+! CFG:      %[[BOX_NONE:.*]] = fir.convert %[[SELECTOR]] : (!fir.class<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>) -> !fir.box<none>
+! CFG:      %[[CLASS_IS:.*]] = fir.call @_FortranAClassIs(%[[BOX_NONE]], %[[TDESC_P1_CONV]]) : (!fir.box<none>, !fir.ref<none>) -> i1
+! CFG:      cf.cond_br %[[CLASS_IS]], ^[[CLASS_IS_BLK:.*]], ^[[NOT_CLASS_IS:.*]]
+! CFG:    ^[[TYPE_IS_P1_BLK]]:
+! CFG:        cf.br ^bb[[END_SELECT_BLK:[0-9]]]
+! CFG:    ^[[NOT_CLASS_IS]]:
+! CFG:        cf.br ^bb[[DEFAULT_BLK:[0-9]]]
+! CFG:    ^[[CLASS_IS_BLK]]:
+! CFG:        cf.br ^bb[[END_SELECT_BLK]]
+! CFG:    ^bb[[DEFAULT_BLK]]:
+! CFG:        cf.br ^bb[[END_SELECT_BLK]]
+! CFG:    ^bb[[END_SELECT_BLK]]:
+! CFG:        return
+
   subroutine select_type4(a)
     class(p1), intent(in) :: a
     select type(a)
@@ -117,6 +208,39 @@ contains
 ! CHECK: ^[[P1]]
 ! CHECK: ^[[EXIT]]
 
+! CFG-LABEL: func.func @_QMselect_type_lower_testPselect_type4(
+! CFG-SAME: %[[ARG0:.*]]: !fir.class<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>> {fir.bindc_name = "a"}) {
+! CFG:      %[[TDESC_P3_8_ADDR:.*]] = fir.address_of(@_QMselect_type_lower_testE.dt.p3.8) : !fir.ref<!fir.type<{{.*}}>>
+! CFG:      %[[BOX_TDESC:.*]] = fir.box_tdesc %[[ARG0]] : (!fir.class<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>) -> !fir.tdesc<none>
+! CFG:      %[[TDESC_P3_8_CONV:.*]] = fir.convert %[[TDESC_P3_8_ADDR]] : (!fir.ref<!fir.type<{{.*}}>>) -> index
+! CFG:      %[[BOX_TDESC_CONV:.*]] = fir.convert %[[BOX_TDESC]] : (!fir.tdesc<none>) -> index
+! CFG:      %[[TDESC_CMP:.*]] = arith.cmpi eq, %[[TDESC_P3_8_CONV]], %[[BOX_TDESC_CONV]] : index
+! CFG:      cf.cond_br %[[TDESC_CMP]], ^[[P3_8_BLK:.*]], ^[[NOT_P3_8_BLK:.*]]
+! CFG:    ^[[NOT_P3_8_BLK]]:
+! CFG:      %[[TDESC_P3_4_ADDR:.*]] = fir.address_of(@_QMselect_type_lower_testE.dt.p3.4) : !fir.ref<!fir.type<{{.*}}>>
+! CFG:      %[[BOX_TDESC:.*]] = fir.box_tdesc %[[ARG0]] : (!fir.class<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>) -> !fir.tdesc<none>
+! CFG:      %[[TDESC_P3_4_CONV:.*]] = fir.convert %[[TDESC_P3_4_ADDR]] : (!fir.ref<!fir.type<{{.*}}>>) -> index
+! CFG:      %[[BOX_TDESC_CONV:.*]] = fir.convert %[[BOX_TDESC]] : (!fir.tdesc<none>) -> index
+! CFG:      %[[TDESC_CMP:.*]] = arith.cmpi eq, %[[TDESC_P3_4_CONV]], %[[BOX_TDESC_CONV]] : index
+! CFG:      cf.cond_br %[[TDESC_CMP]], ^[[P3_4_BLK:.*]], ^[[NOT_P3_4_BLK:.*]]
+! CFG:    ^[[P3_8_BLK]]:
+! CFG:      _FortranAioOutputAscii
+! CFG:      cf.br ^bb[[EXIT_SELECT_BLK:[0-9]]]
+! CFG:    ^[[NOT_P3_4_BLK]]:
+! CFG:      %[[TDESC_P1_ADDR:.*]] = fir.address_of(@_QMselect_type_lower_testE.dt.p1) : !fir.ref<!fir.type<{{.*}}>>
+! CFG:      %[[TDESC_P1_CONV:.*]] = fir.convert %[[TDESC_P1_ADDR]] : (!fir.ref<!fir.type<{{.*}}>>) -> !fir.ref<none>
+! CFG:      %[[BOX_NONE:.*]] = fir.convert %[[ARG0]] : (!fir.class<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>) -> !fir.box<none>
+! CFG:      %[[CLASS_IS:.*]] = fir.call @_FortranAClassIs(%[[BOX_NONE]], %[[TDESC_P1_CONV]]) : (!fir.box<none>, !fir.ref<none>) -> i1
+! CFG:      cf.cond_br %[[CLASS_IS]], ^[[P1_BLK:.*]], ^[[NOT_P1_BLK:.*]]
+! CFG:    ^[[P3_4_BLK]]:
+! CFG:      cf.br ^bb[[EXIT_SELECT_BLK]]
+! CFG:    ^[[NOT_P1_BLK]]:
+! CFG:      cf.br ^bb[[EXIT_SELECT_BLK]]
+! CFG:    ^[[P1_BLK]]:
+! CFG:      cf.br ^bb[[EXIT_SELECT_BLK]]
+! CFG:    ^bb[[EXIT_SELECT_BLK]]:
+! CFG:      return
+
   subroutine select_type5(a)
     class(*), intent(in) :: a
 
@@ -144,9 +268,10 @@ contains
 ! CHECK: ^[[LOG_BLK]]
 ! CHECK: ^[[DEFAULT_BLOCK]]
 
+! CFG-LABEL: func.func @_QMselect_type_lower_testPselect_type5(
 
   subroutine select_type6(a)
-    class(*), intent(out) :: a
+    class(*) :: a
 
     select type(a)
     type is (integer)
@@ -172,6 +297,114 @@ contains
 ! CHECK:  %[[C2:.*]] = arith.constant 2.000000e+00 : f32
 ! CHECK:  fir.store %[[C2]] to %[[BOX_ADDR]] : !fir.ref<f32>
 
+
+! CFG-LABEL: func.func @_QMselect_type_lower_testPselect_type6(
+! CFG-SAME: %[[ARG0:.*]]: !fir.class<none> {fir.bindc_name = "a"})
+! CFG:   %[[INT32_TYPECODE:.*]] = arith.constant 9 : i8
+! CFG:   %[[ARG0_TYPECODE:.*]] = fir.box_typecode %[[ARG0]] : (!fir.class<none>) -> i8
+! CFG:   %[[IS_TYPECODE:.*]] = arith.cmpi eq, %[[ARG0_TYPECODE]], %[[INT32_TYPECODE]] : i8
+! CFG:   cf.cond_br %[[IS_TYPECODE]], ^[[TYPE_IS_INT_BLK:.*]], ^[[TYPE_NOT_INT_BLK:.*]]
+! CFG: ^[[TYPE_NOT_INT_BLK]]:
+! CFG:   %[[FLOAT_TYPECODE:.*]] = arith.constant 27 : i8
+! CFG:   %[[ARG0_TYPECODE:.*]] = fir.box_typecode %[[ARG0]] : (!fir.class<none>) -> i8
+! CFG:   %[[IS_TYPECODE:.*]] = arith.cmpi eq, %[[ARG0_TYPECODE]], %[[FLOAT_TYPECODE]] : i8
+! CFG:   cf.cond_br %[[IS_TYPECODE]], ^[[TYPE_IS_REAL_BLK:.*]], ^[[TYPE_NOT_REAL_BLK:.*]]
+! CFG: ^[[TYPE_IS_INT_BLK]]:
+! CFG:   %[[BOX_ADDR:.*]] = fir.box_addr %[[ARG0]] : (!fir.class<none>) -> !fir.ref<i32>
+! CFG:   %[[C100:.*]] = arith.constant 100 : i32
+! CFG:   fir.store %[[C100]] to %[[BOX_ADDR]] : !fir.ref<i32>
+! CFG:   cf.br ^[[EXIT_SELECT_BLK:.*]]
+! CFG: ^[[TYPE_NOT_REAL_BLK]]:
+! CFG:   cf.br ^[[DEFAULT_BLK:.*]]
+! CFG: ^[[TYPE_IS_REAL_BLK]]:
+! CFG: %[[BOX_ADDR:.*]] = fir.box_addr %[[ARG0]] : (!fir.class<none>) -> !fir.ref<f32>
+! CFG: %[[CST:.*]] = arith.constant 2.000000e+00 : f32
+! CFG: fir.store %[[CST]] to %[[BOX_ADDR]] : !fir.ref<f32>
+! CFG: cf.br ^[[EXIT_SELECT_BLK]]
+! CFG: ^[[DEFAULT_BLK]]:
+! CFG:   fir.call @_FortranAStopStatementText
+! CFG:   fir.unreachable
+! CFG: ^[[EXIT_SELECT_BLK]]:
+! CFG   return
+
+  subroutine select_type7(a)
+    class(*), intent(out) :: a
+
+    select type(a)
+    class is (p1)
+      print*, 'CLASS IS P1'
+    class is (p2)
+      print*, 'CLASS IS P2'
+    class is (p4)
+      print*, 'CLASS IS P4'
+    class default
+      print*, 'CLASS DEFAULT'
+    end select
+  end subroutine
+
+! CHECK-LABEL: func.func @_QMselect_type_lower_testPselect_type7(
+! CHECK-SAME: %[[ARG0:.*]]: !fir.class<none> {fir.bindc_name = "a"})
+! CHECK: fir.select_type %[[ARG0]] :
+! CHECK-SAME: !fir.class<none> [#fir.class_is<!fir.type<_QMselect_type_lower_testTp1{a:i32,b:i32}>>, ^bb1, #fir.class_is<!fir.type<_QMselect_type_lower_testTp2{a:i32,b:i32,c:i32}>>, ^bb2, #fir.class_is<!fir.type<_QMselect_type_lower_testTp4{a:i32,b:i32,c:i32,d:i32}>>, ^bb3, unit, ^bb4]
+
+! Check correct ordering of class is type guard. The expected flow should be:
+!   class is (p4) -> class is (p2) -> class is (p1) -> class default
+
+! CFG-LABEL: func.func @_QMselect_type_lower_testPselect_type7(
+! CFG-SAME: %[[ARG0:.*]]: !fir.class<none> {fir.bindc_name = "a"}) {
+! CFG:      %[[TDESC_P4_ADDR:.*]] = fir.address_of(@_QMselect_type_lower_testE.dt.p4) : !fir.ref<!fir.type<{{.*}}>>
+! CFG:      %[[TDESC_P4_CONV:.*]] = fir.convert %[[TDESC_P4_ADDR]] : (!fir.ref<!fir.type<{{.*}}>>) -> !fir.ref<none>
+! CFG:      %[[BOX_NONE:.*]] = fir.convert %[[ARG0]] : (!fir.class<none>) -> !fir.box<none>
+! CFG:      %[[CLASS_IS_P4:.*]] = fir.call @_FortranAClassIs(%[[BOX_NONE]], %[[TDESC_P4_CONV]]) : (!fir.box<none>, !fir.ref<none>) -> i1
+! CFG:      cf.cond_br %[[CLASS_IS_P4]], ^[[CLASS_IS_P4_BLK:.*]], ^[[CLASS_NOT_P4_BLK:.*]]
+! CFG:    ^bb[[CLASS_NOT_P1_BLK:[0-9]]]:
+! CFG:      cf.br ^[[CLASS_DEFAULT_BLK:.*]]
+! CFG:    ^bb[[CLASS_IS_P1_BLK:[0-9]]]:
+! CFG:      cf.br ^[[EXIT_SELECT_BLK:.*]]
+! CFG:    ^bb[[CLASS_NOT_P2_BLK:[0-9]]]:
+! CFG:      %[[TDESC_P1_ADDR:.*]] = fir.address_of(@_QMselect_type_lower_testE.dt.p1) : !fir.ref<!fir.type<{{.*}}>>
+! CFG:      %[[TDESC_P1_CONV:.*]] = fir.convert %[[TDESC_P1_ADDR]] : (!fir.ref<!fir.type<{{.*}}>>) -> !fir.ref<none>
+! CFG:      %[[BOX_NONE:.*]] = fir.convert %[[ARG0]] : (!fir.class<none>) -> !fir.box<none>
+! CFG:      %[[CLASS_IS_P1:.*]] = fir.call @_FortranAClassIs(%[[BOX_NONE]], %[[TDESC_P1_CONV]]) : (!fir.box<none>, !fir.ref<none>) -> i1
+! CFG:      cf.cond_br %[[CLASS_IS_P1]], ^bb[[CLASS_IS_P1_BLK]], ^bb[[CLASS_NOT_P1_BLK]]
+! CFG:    ^bb[[CLASS_IS_P2_BLK:[0-9]]]:
+! CFG:      cf.br ^[[EXIT_SELECT_BLK]]
+! CFG:    ^[[CLASS_NOT_P4_BLK]]:
+! CFG:      %[[TDESC_P2_ADDR:.*]] = fir.address_of(@_QMselect_type_lower_testE.dt.p2) : !fir.ref<!fir.type<{{.*}}>>
+! CFG:      %[[TDESC_P2_CONV:.*]] = fir.convert %[[TDESC_P2_ADDR]] : (!fir.ref<!fir.type<{{.*}}>>) -> !fir.ref<none>
+! CFG:      %[[BOX_NONE:.*]] = fir.convert %[[ARG0]] : (!fir.class<none>) -> !fir.box<none>
+! CFG:      %[[CLASS_IS_P2:.*]] = fir.call @_FortranAClassIs(%[[BOX_NONE]], %[[TDESC_P2_CONV]]) : (!fir.box<none>, !fir.ref<none>) -> i1
+! CFG:      cf.cond_br %[[CLASS_IS_P2]], ^bb[[CLASS_IS_P2_BLK]], ^bb[[CLASS_NOT_P2_BLK]]
+! CFG:   ^[[CLASS_IS_P4_BLK]]:
+! CFG:      cf.br ^[[EXIT_SELECT_BLK]]
+! CFG:   ^[[CLASS_DEFAULT_BLK]]:
+! CFG:      cf.br ^[[EXIT_SELECT_BLK]]
+! CFG:   ^[[EXIT_SELECT_BLK]]:
+! CFG:      return
+
 end module
 
+program test_select_type
+  use select_type_lower_test
+
+  integer :: a
+  real :: b
+  type(p4) :: t4
+  type(p2) :: t2
+  type(p1) :: t1
+
+  call select_type7(t4)
+  call select_type7(t2)
+  call select_type7(t1)
+
+  call select_type1(t1)
+  call select_type1(t2)
+  call select_type1(t4)
+
+  call select_type6(a)
+  print*, a
+
+  call select_type6(b)
+  print*, b
 
+end