From 1c681a7cafc26266c511fd590da24855f0622fc3 Mon Sep 17 00:00:00 2001 From: Geoffrey Martin-Noble Date: Wed, 29 May 2019 16:07:17 -0700 Subject: [PATCH] Exclude all ShapedType subclasses other than TensorType subclasses from having non-scalar elements. The current logic assumes that ShapedType indicates a vector or tensor, which will not be true soon when MemRef subclasses ShapedType -- PiperOrigin-RevId: 250586364 --- mlir/lib/IR/StandardTypes.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp index 6c0d740..300cec9 100644 --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -147,11 +147,12 @@ int64_t ShapedType::getSizeInBits() const { if (elementType.isIntOrFloat()) return elementType.getIntOrFloatBitWidth() * getNumElements(); - // Tensors can have vectors and other tensors as elements, vectors cannot. - assert(!isa() && "unsupported vector element type"); - auto elementShapedType = elementType.dyn_cast(); - assert(elementShapedType && "unsupported tensor element type"); - return getNumElements() * elementShapedType.getSizeInBits(); + // Tensors can have vectors and other tensors as elements, other shaped types + // cannot. + assert(isa() && "unsupported element type"); + assert((elementType.isa() || elementType.isa()) && + "unsupported tensor element type"); + return getNumElements() * elementType.cast().getSizeInBits(); } ArrayRef ShapedType::getShape() const { -- 2.7.4