From d1cb68525c4a0127f2d823d2c8c49791f55d3553 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Wed, 19 Apr 2023 11:00:48 +0900 Subject: [PATCH] [mlir][IR] Remove ShapedType::getSizeInBits This function returns incorrect values for memrefs and vectors due to "widening". Differential Revision: https://reviews.llvm.org/D148501 --- mlir/include/mlir/IR/BuiltinTypeInterfaces.td | 9 --------- .../Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 5 ++++- mlir/lib/IR/BuiltinTypeInterfaces.cpp | 15 --------------- mlir/test/lib/Dialect/Test/TestOps.td | 4 ++-- mlir/test/mlir-tblgen/op-derived-attribute.mlir | 6 +++--- 5 files changed, 9 insertions(+), 30 deletions(-) diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td index f2b1fa34bc39..bb38985715c0 100644 --- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td +++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td @@ -99,15 +99,6 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> { /// Return the number of elements present in the given shape. static int64_t getNumElements(ArrayRef shape); - - /// Returns the total amount of bits occupied by a value of this type. This - /// does not take into account any memory layout or widening constraints, - /// e.g. a vector<3xi57> may report to occupy 3x57=171 bit, even though in - /// practice it will likely be stored as in a 4xi64 vector register. Fails - /// with an assertion if the size cannot be computed statically, e.g. if the - /// type has a dynamic shape or if its elemental type does not have a known - /// bit width. - int64_t getSizeInBits() const; }]; let extraSharedClassDeclaration = [{ diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index dae90f26199a..50017b7bcef9 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -40,8 +40,11 @@ static uint64_t getFirstIntValue(ArrayAttr attr) { /// Returns the number of bits for the given scalar/vector type. static int getNumBits(Type type) { + // TODO: This does not take into account any memory layout or widening + // constraints. E.g., a vector<3xi57> may report to occupy 3x57=171 bit, even + // though in practice it will likely be stored as in a 4xi64 vector register. if (auto vectorType = type.dyn_cast()) - return vectorType.cast().getSizeInBits(); + return vectorType.getNumElements() * vectorType.getElementTypeBitWidth(); return type.getIntOrFloatBitWidth(); } diff --git a/mlir/lib/IR/BuiltinTypeInterfaces.cpp b/mlir/lib/IR/BuiltinTypeInterfaces.cpp index 88791fc66fbf..ab9e65b5edfe 100644 --- a/mlir/lib/IR/BuiltinTypeInterfaces.cpp +++ b/mlir/lib/IR/BuiltinTypeInterfaces.cpp @@ -33,18 +33,3 @@ int64_t ShapedType::getNumElements(ArrayRef shape) { } return num; } - -int64_t ShapedType::getSizeInBits() const { - assert(hasStaticShape() && - "cannot get the bit size of an aggregate with a dynamic shape"); - - auto elementType = getElementType(); - if (elementType.isIntOrFloat()) - return elementType.getIntOrFloatBitWidth() * getNumElements(); - - if (auto complexType = elementType.dyn_cast()) { - elementType = complexType.getElementType(); - return elementType.getIntOrFloatBitWidth() * getNumElements() * 2; - } - return getNumElements() * elementType.cast().getSizeInBits(); -} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 0306f0ed02f9..9381409f0784 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -259,8 +259,8 @@ def DerivedTypeAttrOp : TEST_Op<"derived_type_attr", []> { let results = (outs AnyTensor:$output); DerivedTypeAttr element_dtype = DerivedTypeAttr<"return getElementTypeOrSelf(getOutput().getType());">; - DerivedAttr size = DerivedAttr<"int", - "return getOutput().getType().cast().getSizeInBits();", + DerivedAttr num_elements = DerivedAttr<"int", + "return getOutput().getType().cast().getNumElements();", "$_builder.getI32IntegerAttr($_self)">; } diff --git a/mlir/test/mlir-tblgen/op-derived-attribute.mlir b/mlir/test/mlir-tblgen/op-derived-attribute.mlir index 27fa3ab821ac..2b0e6ed4994b 100644 --- a/mlir/test/mlir-tblgen/op-derived-attribute.mlir +++ b/mlir/test/mlir-tblgen/op-derived-attribute.mlir @@ -3,15 +3,15 @@ // CHECK-LABEL: verifyDerivedAttributes func.func @verifyDerivedAttributes() { // expected-remark @+2 {{element_dtype = f32}} - // expected-remark @+1 {{size = 320}} + // expected-remark @+1 {{num_elements = 10}} %0 = "test.derived_type_attr"() : () -> tensor<10xf32> // expected-remark @+2 {{element_dtype = i79}} - // expected-remark @+1 {{size = 948}} + // expected-remark @+1 {{num_elements = 12}} %1 = "test.derived_type_attr"() : () -> tensor<12xi79> // expected-remark @+2 {{element_dtype = complex}} - // expected-remark @+1 {{size = 768}} + // expected-remark @+1 {{num_elements = 12}} %2 = "test.derived_type_attr"() : () -> tensor<12xcomplex> return -- 2.34.1