/// Return the number of elements present in the given shape.
static int64_t getNumElements(ArrayRef<int64_t> shape);
-
- /// Returns the total amount of bits occupied by a value of this type. This
- /// does not take into account any memory layout or widening constraints,
- /// e.g. a vector<3xi57> may report to occupy 3x57=171 bit, even though in
- /// practice it will likely be stored as in a 4xi64 vector register. Fails
- /// with an assertion if the size cannot be computed statically, e.g. if the
- /// type has a dynamic shape or if its elemental type does not have a known
- /// bit width.
- int64_t getSizeInBits() const;
}];
let extraSharedClassDeclaration = [{
/// Returns the number of bits for the given scalar/vector type.
static int getNumBits(Type type) {
+ // TODO: This does not take into account any memory layout or widening
+ // constraints. E.g., a vector<3xi57> may report to occupy 3x57=171 bit, even
+ // though in practice it will likely be stored as in a 4xi64 vector register.
if (auto vectorType = type.dyn_cast<VectorType>())
- return vectorType.cast<ShapedType>().getSizeInBits();
+ return vectorType.getNumElements() * vectorType.getElementTypeBitWidth();
return type.getIntOrFloatBitWidth();
}
}
return num;
}
-
-int64_t ShapedType::getSizeInBits() const {
- assert(hasStaticShape() &&
- "cannot get the bit size of an aggregate with a dynamic shape");
-
- auto elementType = getElementType();
- if (elementType.isIntOrFloat())
- return elementType.getIntOrFloatBitWidth() * getNumElements();
-
- if (auto complexType = elementType.dyn_cast<ComplexType>()) {
- elementType = complexType.getElementType();
- return elementType.getIntOrFloatBitWidth() * getNumElements() * 2;
- }
- return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
-}
let results = (outs AnyTensor:$output);
DerivedTypeAttr element_dtype =
DerivedTypeAttr<"return getElementTypeOrSelf(getOutput().getType());">;
- DerivedAttr size = DerivedAttr<"int",
- "return getOutput().getType().cast<ShapedType>().getSizeInBits();",
+ DerivedAttr num_elements = DerivedAttr<"int",
+ "return getOutput().getType().cast<ShapedType>().getNumElements();",
"$_builder.getI32IntegerAttr($_self)">;
}
// CHECK-LABEL: verifyDerivedAttributes
func.func @verifyDerivedAttributes() {
// expected-remark @+2 {{element_dtype = f32}}
- // expected-remark @+1 {{size = 320}}
+ // expected-remark @+1 {{num_elements = 10}}
%0 = "test.derived_type_attr"() : () -> tensor<10xf32>
// expected-remark @+2 {{element_dtype = i79}}
- // expected-remark @+1 {{size = 948}}
+ // expected-remark @+1 {{num_elements = 12}}
%1 = "test.derived_type_attr"() : () -> tensor<12xi79>
// expected-remark @+2 {{element_dtype = complex<f32>}}
- // expected-remark @+1 {{size = 768}}
+ // expected-remark @+1 {{num_elements = 12}}
%2 = "test.derived_type_attr"() : () -> tensor<12xcomplex<f32>>
return