Exclude all ShapedType subclasses other than TensorType subclasses from having...
authorGeoffrey Martin-Noble <gcmn@google.com>
Wed, 29 May 2019 23:07:17 +0000 (16:07 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 03:09:42 +0000 (20:09 -0700)
    The current logic assumes that ShapedType indicates a vector or tensor, which will not be true soon when MemRef subclasses ShapedType

--

PiperOrigin-RevId: 250586364

mlir/lib/IR/StandardTypes.cpp

index 6c0d740..300cec9 100644 (file)
@@ -147,11 +147,12 @@ int64_t ShapedType::getSizeInBits() const {
   if (elementType.isIntOrFloat())
     return elementType.getIntOrFloatBitWidth() * getNumElements();
 
-  // Tensors can have vectors and other tensors as elements, vectors cannot.
-  assert(!isa<VectorType>() && "unsupported vector element type");
-  auto elementShapedType = elementType.dyn_cast<ShapedType>();
-  assert(elementShapedType && "unsupported tensor element type");
-  return getNumElements() * elementShapedType.getSizeInBits();
+  // Tensors can have vectors and other tensors as elements, other shaped types
+  // cannot.
+  assert(isa<TensorType>() && "unsupported element type");
+  assert((elementType.isa<VectorType>() || elementType.isa<TensorType>()) &&
+         "unsupported tensor element type");
+  return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
 }
 
 ArrayRef<int64_t> ShapedType::getShape() const {