[flang][fir] Add shape, shape_shift, and slice types.
authorEric Schweitz <eschweitz@nvidia.com>
Fri, 5 Feb 2021 20:40:39 +0000 (12:40 -0800)
committerEric Schweitz <eschweitz@nvidia.com>
Tue, 9 Feb 2021 16:55:47 +0000 (08:55 -0800)
Adding the FIR types used to describe the Fortran semantics of indexing
FIR arrays.

https://github.com/flang-compiler/f18-llvm-project/pull/267

Differential Revision: https://reviews.llvm.org/D96172

flang/include/flang/Optimizer/Dialect/FIROps.td
flang/include/flang/Optimizer/Dialect/FIRType.h
flang/lib/Optimizer/Dialect/FIRDialect.cpp
flang/lib/Optimizer/Dialect/FIRType.cpp
flang/test/Fir/fir-types.fir

index cde5372..5eee855 100644 (file)
@@ -99,7 +99,14 @@ def AnyRefOrBox : TypeConstraint<Or<[fir_ReferenceType.predicate,
     fir_HeapType.predicate, fir_PointerType.predicate, fir_BoxType.predicate]>,
     "any reference or box">;
 
-// A vector of Fortran triple notation describing a multidimensional array
+def fir_ShapeType : Type<CPred<"$_self.isa<fir::ShapeType>()">, "shape type">;
+def fir_ShapeShiftType : Type<CPred<"$_self.isa<fir::ShapeShiftType>()">,
+    "shape shift type">;
+def AnyShapeLike : TypeConstraint<Or<[fir_ShapeType.predicate,
+    fir_ShapeShiftType.predicate]>, "any legal shape type">;
+def AnyShapeType : Type<AnyShapeLike.predicate, "any legal shape type">;
+def fir_SliceType : Type<CPred<"$_self.isa<fir::SliceType>()">, "slice type">;
+
 def AnyEmboxLike : TypeConstraint<Or<[AnySignlessInteger.predicate,
     Index.predicate, fir_IntegerType.predicate]>,
     "any legal embox argument type">;
index be984ed..a10aef5 100644 (file)
@@ -54,6 +54,9 @@ struct RealTypeStorage;
 struct RecordTypeStorage;
 struct ReferenceTypeStorage;
 struct SequenceTypeStorage;
+struct ShapeTypeStorage;
+struct ShapeShiftTypeStorage;
+struct SliceTypeStorage;
 struct TypeDescTypeStorage;
 struct VectorTypeStorage;
 } // namespace detail
@@ -216,6 +219,41 @@ public:
                                                           mlir::Type eleTy);
 };
 
+/// Type of a vector of runtime values that define the shape of a
+/// multidimensional array object. The vector is the extents of each array
+/// dimension. The rank of a ShapeType must be at least 1.
+class ShapeType : public mlir::Type::TypeBase<ShapeType, mlir::Type,
+                                              detail::ShapeTypeStorage> {
+public:
+  using Base::Base;
+  static ShapeType get(mlir::MLIRContext *ctx, unsigned rank);
+  unsigned getRank() const;
+};
+
+/// Type of a vector of runtime values that define the shape and the origin of a
+/// multidimensional array object. The vector is of pairs, origin offset and
+/// extent, of each array dimension. The rank of a ShapeShiftType must be at
+/// least 1.
+class ShapeShiftType
+    : public mlir::Type::TypeBase<ShapeShiftType, mlir::Type,
+                                  detail::ShapeShiftTypeStorage> {
+public:
+  using Base::Base;
+  static ShapeShiftType get(mlir::MLIRContext *ctx, unsigned rank);
+  unsigned getRank() const;
+};
+
+/// Type of a vector that represents an array slice operation on an array.
+/// Fortran slices are triples of lower bound, upper bound, and stride. The rank
+/// of a SliceType must be at least 1.
+class SliceType : public mlir::Type::TypeBase<SliceType, mlir::Type,
+                                              detail::SliceTypeStorage> {
+public:
+  using Base::Base;
+  static SliceType get(mlir::MLIRContext *ctx, unsigned rank);
+  unsigned getRank() const;
+};
+
 /// The type of a field name. Implementations may defer the layout of a Fortran
 /// derived type until runtime. This implies that the runtime must be able to
 /// determine the offset of fields within the entity.
index c424b98..24df36c 100644 (file)
@@ -5,6 +5,10 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
 
 #include "flang/Optimizer/Dialect/FIRDialect.h"
 #include "flang/Optimizer/Dialect/FIRAttr.h"
@@ -18,7 +22,8 @@ fir::FIROpsDialect::FIROpsDialect(mlir::MLIRContext *ctx)
   addTypes<BoxType, BoxCharType, BoxProcType, CharacterType, fir::ComplexType,
            FieldType, HeapType, fir::IntegerType, LenType, LogicalType,
            PointerType, RealType, RecordType, ReferenceType, SequenceType,
-           TypeDescType, fir::VectorType>();
+           ShapeType, ShapeShiftType, SliceType, TypeDescType,
+           fir::VectorType>();
   addAttributes<ClosedIntervalAttr, ExactTypeAttr, LowerBoundAttr, OpaqueAttr,
                 PointIntervalAttr, RealAttr, SubclassAttr, UpperBoundAttr>();
   addOperations<
index 3124482..863babe 100644 (file)
@@ -112,6 +112,21 @@ fir::ComplexType parseComplex(mlir::DialectAsmParser &parser) {
   return parseKindSingleton<fir::ComplexType>(parser);
 }
 
+// `shape` `<` rank `>`
+ShapeType parseShape(mlir::DialectAsmParser &parser) {
+  return parseRankSingleton<ShapeType>(parser);
+}
+
+// `shapeshift` `<` rank `>`
+ShapeShiftType parseShapeShift(mlir::DialectAsmParser &parser) {
+  return parseRankSingleton<ShapeShiftType>(parser);
+}
+
+// `slice` `<` rank `>`
+SliceType parseSlice(mlir::DialectAsmParser &parser) {
+  return parseRankSingleton<SliceType>(parser);
+}
+
 // `field`
 FieldType parseField(mlir::DialectAsmParser &parser) {
   return FieldType::get(parser.getBuilder().getContext());
@@ -215,7 +230,9 @@ static bool isaIntegerType(mlir::Type ty) {
 
 bool verifyRecordMemberType(mlir::Type ty) {
   return !(ty.isa<BoxType>() || ty.isa<BoxCharType>() ||
-           ty.isa<BoxProcType>() || ty.isa<FieldType>() || ty.isa<LenType>() ||
+           ty.isa<BoxProcType>() || ty.isa<ShapeType>() ||
+           ty.isa<ShapeShiftType>() || ty.isa<SliceType>() ||
+           ty.isa<FieldType>() || ty.isa<LenType>() ||
            ty.isa<ReferenceType>() || ty.isa<TypeDescType>());
 }
 
@@ -369,6 +386,12 @@ mlir::Type fir::parseFirType(FIROpsDialect *, mlir::DialectAsmParser &parser) {
     return parseReal(parser);
   if (typeNameLit == "ref")
     return parseReference(parser, loc);
+  if (typeNameLit == "shape")
+    return parseShape(parser);
+  if (typeNameLit == "shapeshift")
+    return parseShapeShift(parser);
+  if (typeNameLit == "slice")
+    return parseSlice(parser);
   if (typeNameLit == "tdesc")
     return parseTypeDesc(parser, loc);
   if (typeNameLit == "type")
@@ -420,6 +443,75 @@ private:
       : kind{kind}, len{len} {}
 };
 
+struct ShapeTypeStorage : public mlir::TypeStorage {
+  using KeyTy = unsigned;
+
+  static unsigned hashKey(const KeyTy &key) { return llvm::hash_combine(key); }
+
+  bool operator==(const KeyTy &key) const { return key == getRank(); }
+
+  static ShapeTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
+                                     unsigned rank) {
+    auto *storage = allocator.allocate<ShapeTypeStorage>();
+    return new (storage) ShapeTypeStorage{rank};
+  }
+
+  unsigned getRank() const { return rank; }
+
+protected:
+  unsigned rank;
+
+private:
+  ShapeTypeStorage() = delete;
+  explicit ShapeTypeStorage(unsigned rank) : rank{rank} {}
+};
+
+struct ShapeShiftTypeStorage : public mlir::TypeStorage {
+  using KeyTy = unsigned;
+
+  static unsigned hashKey(const KeyTy &key) { return llvm::hash_combine(key); }
+
+  bool operator==(const KeyTy &key) const { return key == getRank(); }
+
+  static ShapeShiftTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
+                                          unsigned rank) {
+    auto *storage = allocator.allocate<ShapeShiftTypeStorage>();
+    return new (storage) ShapeShiftTypeStorage{rank};
+  }
+
+  unsigned getRank() const { return rank; }
+
+protected:
+  unsigned rank;
+
+private:
+  ShapeShiftTypeStorage() = delete;
+  explicit ShapeShiftTypeStorage(unsigned rank) : rank{rank} {}
+};
+
+struct SliceTypeStorage : public mlir::TypeStorage {
+  using KeyTy = unsigned;
+
+  static unsigned hashKey(const KeyTy &key) { return llvm::hash_combine(key); }
+
+  bool operator==(const KeyTy &key) const { return key == getRank(); }
+
+  static SliceTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
+                                     unsigned rank) {
+    auto *storage = allocator.allocate<SliceTypeStorage>();
+    return new (storage) SliceTypeStorage{rank};
+  }
+
+  unsigned getRank() const { return rank; }
+
+protected:
+  unsigned rank;
+
+private:
+  SliceTypeStorage() = delete;
+  explicit SliceTypeStorage(unsigned rank) : rank{rank} {}
+};
+
 /// The type of a derived type part reference
 struct FieldTypeStorage : public mlir::TypeStorage {
   using KeyTy = KindTy;
@@ -894,11 +986,13 @@ bool isa_box_type(mlir::Type t) {
 }
 
 bool isa_passbyref_type(mlir::Type t) {
-  return t.isa<ReferenceType>() || isa_box_type(t);
+  return t.isa<ReferenceType>() || isa_box_type(t) ||
+         t.isa<mlir::FunctionType>();
 }
 
 bool isa_aggregate(mlir::Type t) {
-  return t.isa<SequenceType>() || t.isa<RecordType>();
+  return t.isa<SequenceType>() || t.isa<RecordType>() ||
+         t.isa<mlir::TupleType>();
 }
 
 mlir::Type dyn_cast_ptrEleTy(mlir::Type t) {
@@ -1036,8 +1130,10 @@ mlir::Type fir::ReferenceType::getEleTy() const {
 mlir::LogicalResult
 fir::ReferenceType::verifyConstructionInvariants(mlir::Location loc,
                                                  mlir::Type eleTy) {
-  if (eleTy.isa<FieldType>() || eleTy.isa<LenType>() ||
-      eleTy.isa<ReferenceType>() || eleTy.isa<TypeDescType>())
+  if (eleTy.isa<ShapeType>() || eleTy.isa<ShapeShiftType>() ||
+      eleTy.isa<SliceType>() || eleTy.isa<FieldType>() ||
+      eleTy.isa<LenType>() || eleTy.isa<ReferenceType>() ||
+      eleTy.isa<TypeDescType>())
     return mlir::emitError(loc, "cannot build a reference to type: ")
            << eleTy << '\n';
   return mlir::success();
@@ -1056,10 +1152,11 @@ mlir::Type fir::PointerType::getEleTy() const {
 
 static bool canBePointerOrHeapElementType(mlir::Type eleTy) {
   return eleTy.isa<BoxType>() || eleTy.isa<BoxCharType>() ||
-         eleTy.isa<BoxProcType>() || eleTy.isa<FieldType>() ||
-         eleTy.isa<LenType>() || eleTy.isa<HeapType>() ||
-         eleTy.isa<PointerType>() || eleTy.isa<ReferenceType>() ||
-         eleTy.isa<TypeDescType>();
+         eleTy.isa<BoxProcType>() || eleTy.isa<ShapeType>() ||
+         eleTy.isa<ShapeShiftType>() || eleTy.isa<SliceType>() ||
+         eleTy.isa<FieldType>() || eleTy.isa<LenType>() ||
+         eleTy.isa<HeapType>() || eleTy.isa<PointerType>() ||
+         eleTy.isa<ReferenceType>() || eleTy.isa<TypeDescType>();
 }
 
 mlir::LogicalResult
@@ -1144,8 +1241,9 @@ mlir::LogicalResult fir::SequenceType::verifyConstructionInvariants(
     mlir::AffineMapAttr map) {
   // DIMENSION attribute can only be applied to an intrinsic or record type
   if (eleTy.isa<BoxType>() || eleTy.isa<BoxCharType>() ||
-      eleTy.isa<BoxProcType>() || eleTy.isa<FieldType>() ||
-      eleTy.isa<LenType>() || eleTy.isa<HeapType>() ||
+      eleTy.isa<BoxProcType>() || eleTy.isa<ShapeType>() ||
+      eleTy.isa<ShapeShiftType>() || eleTy.isa<SliceType>() ||
+      eleTy.isa<FieldType>() || eleTy.isa<LenType>() || eleTy.isa<HeapType>() ||
       eleTy.isa<PointerType>() || eleTy.isa<ReferenceType>() ||
       eleTy.isa<TypeDescType>() || eleTy.isa<fir::VectorType>() ||
       eleTy.isa<SequenceType>())
@@ -1154,27 +1252,6 @@ mlir::LogicalResult fir::SequenceType::verifyConstructionInvariants(
   return mlir::success();
 }
 
-//===----------------------------------------------------------------------===//
-// Vector type
-//===----------------------------------------------------------------------===//
-
-fir::VectorType fir::VectorType::get(uint64_t len, mlir::Type eleTy) {
-  return Base::get(eleTy.getContext(), len, eleTy);
-}
-
-mlir::Type fir::VectorType::getEleTy() const { return getImpl()->getEleTy(); }
-
-uint64_t fir::VectorType::getLen() const { return getImpl()->getLen(); }
-
-mlir::LogicalResult
-fir::VectorType::verifyConstructionInvariants(mlir::Location loc, uint64_t len,
-                                              mlir::Type eleTy) {
-  if (!(fir::isa_real(eleTy) || fir::isa_integer(eleTy)))
-    return mlir::emitError(loc, "cannot build a vector of type ")
-           << eleTy << '\n';
-  return mlir::success();
-}
-
 // compare if two shapes are equivalent
 bool fir::operator==(const SequenceType::Shape &sh_1,
                      const SequenceType::Shape &sh_2) {
@@ -1195,6 +1272,31 @@ llvm::hash_code fir::hash_value(const SequenceType::Shape &sh) {
   return llvm::hash_combine(0);
 }
 
+// Shape
+
+ShapeType fir::ShapeType::get(mlir::MLIRContext *ctxt, unsigned rank) {
+  return Base::get(ctxt, rank);
+}
+
+unsigned fir::ShapeType::getRank() const { return getImpl()->getRank(); }
+
+// Shapeshift
+
+ShapeShiftType fir::ShapeShiftType::get(mlir::MLIRContext *ctxt,
+                                        unsigned rank) {
+  return Base::get(ctxt, rank);
+}
+
+unsigned fir::ShapeShiftType::getRank() const { return getImpl()->getRank(); }
+
+// Slice
+
+SliceType fir::SliceType::get(mlir::MLIRContext *ctxt, unsigned rank) {
+  return Base::get(ctxt, rank);
+}
+
+unsigned fir::SliceType::getRank() const { return getImpl()->getRank(); }
+
 /// RecordType
 ///
 /// This type captures a Fortran "derived type"
@@ -1237,9 +1339,9 @@ mlir::Type fir::RecordType::getType(llvm::StringRef ident) {
   return {};
 }
 
-/// Type descriptor type
-///
-/// This is the type of a type descriptor object (similar to a class instance)
+//===----------------------------------------------------------------------===//
+// Type descriptor type
+//===----------------------------------------------------------------------===//
 
 TypeDescType fir::TypeDescType::get(mlir::Type ofType) {
   assert(!ofType.isa<ReferenceType>());
@@ -1252,14 +1354,36 @@ mlir::LogicalResult
 fir::TypeDescType::verifyConstructionInvariants(mlir::Location loc,
                                                 mlir::Type eleTy) {
   if (eleTy.isa<BoxType>() || eleTy.isa<BoxCharType>() ||
-      eleTy.isa<BoxProcType>() || eleTy.isa<FieldType>() ||
-      eleTy.isa<LenType>() || eleTy.isa<ReferenceType>() ||
-      eleTy.isa<TypeDescType>())
+      eleTy.isa<BoxProcType>() || eleTy.isa<ShapeType>() ||
+      eleTy.isa<ShapeShiftType>() || eleTy.isa<SliceType>() ||
+      eleTy.isa<FieldType>() || eleTy.isa<LenType>() ||
+      eleTy.isa<ReferenceType>() || eleTy.isa<TypeDescType>())
     return mlir::emitError(loc, "cannot build a type descriptor of type: ")
            << eleTy << '\n';
   return mlir::success();
 }
 
+//===----------------------------------------------------------------------===//
+// Vector type
+//===----------------------------------------------------------------------===//
+
+fir::VectorType fir::VectorType::get(uint64_t len, mlir::Type eleTy) {
+  return Base::get(eleTy.getContext(), len, eleTy);
+}
+
+mlir::Type fir::VectorType::getEleTy() const { return getImpl()->getEleTy(); }
+
+uint64_t fir::VectorType::getLen() const { return getImpl()->getLen(); }
+
+mlir::LogicalResult
+fir::VectorType::verifyConstructionInvariants(mlir::Location loc, uint64_t len,
+                                              mlir::Type eleTy) {
+  if (!(fir::isa_real(eleTy) || fir::isa_integer(eleTy)))
+    return mlir::emitError(loc, "cannot build a vector of type ")
+           << eleTy << '\n';
+  return mlir::success();
+}
+
 namespace {
 
 void printBounds(llvm::raw_ostream &os, const SequenceType::Shape &bounds) {
@@ -1321,10 +1445,12 @@ void fir::printFirType(FIROpsDialect *, mlir::Type ty,
     return;
   }
   if (auto type = ty.dyn_cast<fir::ComplexType>()) {
+    // Fortran intrinsic type COMPLEX
     os << "complex<" << type.getFKind() << '>';
     return;
   }
   if (auto type = ty.dyn_cast<RecordType>()) {
+    // Fortran derived type
     os << "type<" << type.getName();
     if (!recordTypeVisited.count(type.uniqueKey())) {
       recordTypeVisited.insert(type.uniqueKey());
@@ -1351,6 +1477,18 @@ void fir::printFirType(FIROpsDialect *, mlir::Type ty,
     os << '>';
     return;
   }
+  if (auto type = ty.dyn_cast<ShapeType>()) {
+    os << "shape<" << type.getRank() << '>';
+    return;
+  }
+  if (auto type = ty.dyn_cast<ShapeShiftType>()) {
+    os << "shapeshift<" << type.getRank() << '>';
+    return;
+  }
+  if (auto type = ty.dyn_cast<SliceType>()) {
+    os << "slice<" << type.getRank() << '>';
+    return;
+  }
   if (ty.isa<FieldType>()) {
     os << "field";
     return;
@@ -1362,6 +1500,7 @@ void fir::printFirType(FIROpsDialect *, mlir::Type ty,
     return;
   }
   if (auto type = ty.dyn_cast<fir::IntegerType>()) {
+    // Fortran intrinsic type INTEGER
     os << "int<" << type.getFKind() << '>';
     return;
   }
@@ -1370,6 +1509,7 @@ void fir::printFirType(FIROpsDialect *, mlir::Type ty,
     return;
   }
   if (auto type = ty.dyn_cast<LogicalType>()) {
+    // Fortran intrinsic type LOGICAL
     os << "logical<" << type.getFKind() << '>';
     return;
   }
@@ -1380,6 +1520,7 @@ void fir::printFirType(FIROpsDialect *, mlir::Type ty,
     return;
   }
   if (auto type = ty.dyn_cast<fir::RealType>()) {
+    // Fortran intrinsic types REAL and DOUBLE PRECISION
     os << "real<" << type.getFKind() << '>';
     return;
   }
index 0a6772d..2a46301 100644 (file)
@@ -71,10 +71,16 @@ func private @box4() -> !fir.box<none>
 func private @box5() -> !fir.box<!fir.type<derived3{f:f32}>>
 
 // FIR misc. types
+// CHECK-LABEL: func private @oth1() -> !fir.shape<1>
 // CHECK-LABEL: func private @oth2() -> !fir.field
 // CHECK-LABEL: func private @oth3() -> !fir.tdesc<!fir.type<derived7{f1:f32,f2:f32}>>
+// CHECK-LABEL: func private @oth4() -> !fir.shapeshift<15>
+// CHECK-LABEL: func private @oth5() -> !fir.slice<8>
+func private @oth1() -> !fir.shape<1>
 func private @oth2() -> !fir.field
 func private @oth3() -> !fir.tdesc<!fir.type<derived7{f1:f32,f2:f32}>>
+func private @oth4() -> !fir.shapeshift<15>
+func private @oth5() -> !fir.slice<8>
 
 // FIR vector
 // CHECK-LABEL: func private @vecty(i1) -> !fir.vector<10:i32>