[Shape] Simplify getShapeVec a bit. NFCI.
authorBenjamin Kramer <benny.kra@googlemail.com>
Sun, 13 Feb 2022 15:57:48 +0000 (16:57 +0100)
committerBenjamin Kramer <benny.kra@googlemail.com>
Sun, 13 Feb 2022 15:58:16 +0000 (16:58 +0100)
mlir/lib/Dialect/Shape/IR/Shape.cpp

index 89a9766..5c851f5 100644 (file)
@@ -47,19 +47,15 @@ bool shape::isExtentTensorType(Type type) {
 LogicalResult shape::getShapeVec(Value input,
                                  SmallVectorImpl<int64_t> &shapeValues) {
   if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
-    auto type = inputOp.getArg().getType().dyn_cast<ShapedType>();
+    auto type = inputOp.getArg().getType().cast<ShapedType>();
     if (!type.hasRank())
       return failure();
-    shapeValues = llvm::to_vector<6>(type.getShape());
+    llvm::append_range(shapeValues, type.getShape());
     return success();
   }
-  if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
-    shapeValues = llvm::to_vector<6>(inputOp.getShape().getValues<int64_t>());
-    return success();
-  }
-  if (auto inputOp = input.getDefiningOp<arith::ConstantOp>()) {
-    shapeValues = llvm::to_vector<6>(
-        inputOp.getValue().cast<DenseIntElementsAttr>().getValues<int64_t>());
+  DenseIntElementsAttr attr;
+  if (matchPattern(input, m_Constant(&attr))) {
+    llvm::append_range(shapeValues, attr.getValues<int64_t>());
     return success();
   }
   return failure();