[mlir] forward data layout query to scoping op in absence of specification
authorAlex Zinenko <zinenko@google.com>
Mon, 22 Mar 2021 13:58:13 +0000 (14:58 +0100)
committerAlex Zinenko <zinenko@google.com>
Wed, 24 Mar 2021 14:13:41 +0000 (15:13 +0100)
Even if the layout specification is missing from an op that supports it, the op
is still expected to provide meaningful responses to data layout queries.
Forward them to the op instead of directly calling the default implementation.

Depends On D98524

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D98525

mlir/lib/Interfaces/DataLayoutInterfaces.cpp
mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp

index 14fb485..c91a723 100644 (file)
@@ -280,52 +280,48 @@ static unsigned cachedLookup(Type t, DenseMap<Type, unsigned> &cache,
 unsigned mlir::DataLayout::getTypeSize(Type t) const {
   checkValid();
   return cachedLookup(t, sizes, [&](Type ty) {
-    if (originalLayout) {
-      DataLayoutEntryList list = originalLayout.getSpecForType(ty.getTypeID());
-      if (auto iface = dyn_cast<DataLayoutOpInterface>(scope))
-        return iface.getTypeSize(ty, *this, list);
-      return detail::getDefaultTypeSize(ty, *this, list);
-    }
-    return detail::getDefaultTypeSize(ty, *this, {});
+    DataLayoutEntryList list;
+    if (originalLayout)
+      list = originalLayout.getSpecForType(ty.getTypeID());
+    if (auto iface = dyn_cast_or_null<DataLayoutOpInterface>(scope))
+      return iface.getTypeSize(ty, *this, list);
+    return detail::getDefaultTypeSize(ty, *this, list);
   });
 }
 
 unsigned mlir::DataLayout::getTypeSizeInBits(Type t) const {
   checkValid();
   return cachedLookup(t, bitsizes, [&](Type ty) {
-    if (originalLayout) {
-      DataLayoutEntryList list = originalLayout.getSpecForType(ty.getTypeID());
-      if (auto iface = dyn_cast<DataLayoutOpInterface>(scope))
-        return iface.getTypeSizeInBits(ty, *this, list);
-      return detail::getDefaultTypeSizeInBits(ty, *this, list);
-    }
-    return detail::getDefaultTypeSizeInBits(ty, *this, {});
+    DataLayoutEntryList list;
+    if (originalLayout)
+      list = originalLayout.getSpecForType(ty.getTypeID());
+    if (auto iface = dyn_cast_or_null<DataLayoutOpInterface>(scope))
+      return iface.getTypeSizeInBits(ty, *this, list);
+    return detail::getDefaultTypeSizeInBits(ty, *this, list);
   });
 }
 
 unsigned mlir::DataLayout::getTypeABIAlignment(Type t) const {
   checkValid();
   return cachedLookup(t, abiAlignments, [&](Type ty) {
-    if (originalLayout) {
-      DataLayoutEntryList list = originalLayout.getSpecForType(ty.getTypeID());
-      if (auto iface = dyn_cast<DataLayoutOpInterface>(scope))
-        return iface.getTypeABIAlignment(ty, *this, list);
-      return detail::getDefaultABIAlignment(ty, *this, list);
-    }
-    return detail::getDefaultABIAlignment(ty, *this, {});
+    DataLayoutEntryList list;
+    if (originalLayout)
+      list = originalLayout.getSpecForType(ty.getTypeID());
+    if (auto iface = dyn_cast_or_null<DataLayoutOpInterface>(scope))
+      return iface.getTypeABIAlignment(ty, *this, list);
+    return detail::getDefaultABIAlignment(ty, *this, list);
   });
 }
 
 unsigned mlir::DataLayout::getTypePreferredAlignment(Type t) const {
   checkValid();
   return cachedLookup(t, preferredAlignments, [&](Type ty) {
-    if (originalLayout) {
-      DataLayoutEntryList list = originalLayout.getSpecForType(ty.getTypeID());
-      if (auto iface = dyn_cast<DataLayoutOpInterface>(scope))
-        return iface.getTypePreferredAlignment(ty, *this, list);
-      return detail::getDefaultPreferredAlignment(ty, *this, list);
-    }
-    return detail::getDefaultPreferredAlignment(ty, *this, {});
+    DataLayoutEntryList list;
+    if (originalLayout)
+      list = originalLayout.getSpecForType(ty.getTypeID());
+    if (auto iface = dyn_cast_or_null<DataLayoutOpInterface>(scope))
+      return iface.getTypePreferredAlignment(ty, *this, list);
+    return detail::getDefaultPreferredAlignment(ty, *this, list);
   });
 }
 
index e9d69f0..2878391 100644 (file)
@@ -227,7 +227,7 @@ struct DLTestDialect : Dialect {
 
 TEST(DataLayout, FallbackDefault) {
   const char *ir = R"MLIR(
-"dltest.op_with_layout"() : () -> ()
+module {}
   )MLIR";
 
   DialectRegistry registry;
@@ -235,9 +235,7 @@ TEST(DataLayout, FallbackDefault) {
   MLIRContext ctx(registry);
 
   OwningModuleRef module = parseSourceString(ir, &ctx);
-  auto op =
-      cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
-  DataLayout layout(op);
+  DataLayout layout(module.get());
   EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 6u);
   EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 2u);
   EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 42u);
@@ -248,6 +246,29 @@ TEST(DataLayout, FallbackDefault) {
   EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 2u);
 }
 
+TEST(DataLayout, NullSpec) {
+  const char *ir = R"MLIR(
+"dltest.op_with_layout"() : () -> ()
+  )MLIR";
+
+  DialectRegistry registry;
+  registry.insert<DLTIDialect, DLTestDialect>();
+  MLIRContext ctx(registry);
+
+  OwningModuleRef module = parseSourceString(ir, &ctx);
+  auto op =
+      cast<DataLayoutOpInterface>(module->getBody()->getOperations().front());
+  DataLayout layout(op);
+  EXPECT_EQ(layout.getTypeSize(IntegerType::get(&ctx, 42)), 42u);
+  EXPECT_EQ(layout.getTypeSize(Float16Type::get(&ctx)), 16u);
+  EXPECT_EQ(layout.getTypeSizeInBits(IntegerType::get(&ctx, 42)), 8u * 42u);
+  EXPECT_EQ(layout.getTypeSizeInBits(Float16Type::get(&ctx)), 8u * 16u);
+  EXPECT_EQ(layout.getTypeABIAlignment(IntegerType::get(&ctx, 42)), 64u);
+  EXPECT_EQ(layout.getTypeABIAlignment(Float16Type::get(&ctx)), 16u);
+  EXPECT_EQ(layout.getTypePreferredAlignment(IntegerType::get(&ctx, 42)), 128u);
+  EXPECT_EQ(layout.getTypePreferredAlignment(Float16Type::get(&ctx)), 32u);
+}
+
 TEST(DataLayout, EmptySpec) {
   const char *ir = R"MLIR(
 "dltest.op_with_layout"() { dltest.layout = #dltest.spec< > } : () -> ()