[mlir][IR] Remove ShapedType::getSizeInBits
authorMatthias Springer <springerm@google.com>
Wed, 19 Apr 2023 02:00:48 +0000 (11:00 +0900)
committerMatthias Springer <springerm@google.com>
Wed, 19 Apr 2023 02:01:33 +0000 (11:01 +0900)
This function returns incorrect values for memrefs and vectors due to "widening".

Differential Revision: https://reviews.llvm.org/D148501

mlir/include/mlir/IR/BuiltinTypeInterfaces.td
mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
mlir/lib/IR/BuiltinTypeInterfaces.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/op-derived-attribute.mlir

index f2b1fa34bc391716cbfa02db5fc776bf749aae06..bb38985715c097702e27166031c63aba1fb6bce0 100644 (file)
@@ -99,15 +99,6 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
 
     /// Return the number of elements present in the given shape.
     static int64_t getNumElements(ArrayRef<int64_t> shape);
-
-    /// Returns the total amount of bits occupied by a value of this type. This
-    /// does not take into account any memory layout or widening constraints,
-    /// e.g. a vector<3xi57> may report to occupy 3x57=171 bit, even though in
-    /// practice it will likely be stored as in a 4xi64 vector register. Fails
-    /// with an assertion if the size cannot be computed statically, e.g. if the
-    /// type has a dynamic shape or if its elemental type does not have a known
-    /// bit width.
-    int64_t getSizeInBits() const;
   }];
 
   let extraSharedClassDeclaration = [{
index dae90f26199a4007f2c1cc742c15c704b9ae7a9f..50017b7bcef9b754144cfeea44f7724e89ad9314 100644 (file)
@@ -40,8 +40,11 @@ static uint64_t getFirstIntValue(ArrayAttr attr) {
 
 /// Returns the number of bits for the given scalar/vector type.
 static int getNumBits(Type type) {
+  // TODO: This does not take into account any memory layout or widening
+  // constraints. E.g., a vector<3xi57> may report to occupy 3x57=171 bit, even
+  // though in practice it will likely be stored as in a 4xi64 vector register.
   if (auto vectorType = type.dyn_cast<VectorType>())
-    return vectorType.cast<ShapedType>().getSizeInBits();
+    return vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
   return type.getIntOrFloatBitWidth();
 }
 
index 88791fc66fbf2d40df57ae90f6152a5ee7b2a58c..ab9e65b5edfed3f339cef317e2d4a75737f949b6 100644 (file)
@@ -33,18 +33,3 @@ int64_t ShapedType::getNumElements(ArrayRef<int64_t> shape) {
   }
   return num;
 }
-
-int64_t ShapedType::getSizeInBits() const {
-  assert(hasStaticShape() &&
-         "cannot get the bit size of an aggregate with a dynamic shape");
-
-  auto elementType = getElementType();
-  if (elementType.isIntOrFloat())
-    return elementType.getIntOrFloatBitWidth() * getNumElements();
-
-  if (auto complexType = elementType.dyn_cast<ComplexType>()) {
-    elementType = complexType.getElementType();
-    return elementType.getIntOrFloatBitWidth() * getNumElements() * 2;
-  }
-  return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
-}
index 0306f0ed02f9931cdaeb7d9868945926830473c1..9381409f078445374e3065e6339371f005326498 100644 (file)
@@ -259,8 +259,8 @@ def DerivedTypeAttrOp : TEST_Op<"derived_type_attr", []> {
   let results = (outs AnyTensor:$output);
   DerivedTypeAttr element_dtype =
     DerivedTypeAttr<"return getElementTypeOrSelf(getOutput().getType());">;
-  DerivedAttr size = DerivedAttr<"int",
-    "return getOutput().getType().cast<ShapedType>().getSizeInBits();",
+  DerivedAttr num_elements = DerivedAttr<"int",
+    "return getOutput().getType().cast<ShapedType>().getNumElements();",
     "$_builder.getI32IntegerAttr($_self)">;
 }
 
index 27fa3ab821ac019204d3b3eaebb0ab481c039217..2b0e6ed4994b6d9203b7c13f2dc2b17ddc5fa8c7 100644 (file)
@@ -3,15 +3,15 @@
 // CHECK-LABEL: verifyDerivedAttributes
 func.func @verifyDerivedAttributes() {
   // expected-remark @+2 {{element_dtype = f32}}
-  // expected-remark @+1 {{size = 320}}
+  // expected-remark @+1 {{num_elements = 10}}
   %0 = "test.derived_type_attr"() : () -> tensor<10xf32>
 
   // expected-remark @+2 {{element_dtype = i79}}
-  // expected-remark @+1 {{size = 948}}
+  // expected-remark @+1 {{num_elements = 12}}
   %1 = "test.derived_type_attr"() : () -> tensor<12xi79>
 
   // expected-remark @+2 {{element_dtype = complex<f32>}}
-  // expected-remark @+1 {{size = 768}}
+  // expected-remark @+1 {{num_elements = 12}}
   %2 = "test.derived_type_attr"() : () -> tensor<12xcomplex<f32>>
 
   return