From 22a8bc6ec393dcb2309d171f0a849105959e5c90 Mon Sep 17 00:00:00 2001 From: Geoffrey Martin-Noble Date: Sat, 18 May 2019 05:31:35 -0700 Subject: [PATCH] Make shape matching work for any shaped type. The current implementation makes some assumptions about what can be a shaped type, which aren't really necessary. It also has strange behavior for types that aren't in the limited set it handles (e.g. dialect-defined types) Updated the comment to match the implementation. This is partially motivated by the desire to make MemRef a subclass of ShapedType -- PiperOrigin-RevId: 248859674 --- mlir/lib/IR/Operation.cpp | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index bc8947c..df6a6cf 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -767,24 +767,21 @@ LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op, } /// Returns success if the given two types have the same shape. That is, -/// they are both scalars, or they are both static shaped types with the same -/// dimension specifications. The element type does not matter. +/// they are both scalars (not shaped), or they are both shaped types and at +/// least one is unranked or they have the same shape. The element type does not +/// matter. static LogicalResult verifyShapeMatch(Type type1, Type type2) { - // Check scalar cases - if (type1.isIntOrIndexOrFloat()) - return success(type2.isIntOrIndexOrFloat()); + auto sType1 = type1.dyn_cast(); + auto sType2 = type2.dyn_cast(); - // Check unranked tensor cases - if (type1.isa() || type2.isa()) - return success(); + // Either both or neither type should be shaped. + if (!sType1) + return success(!sType2); - // Check normal vector/tensor cases - if (auto sType1 = type1.dyn_cast()) { - auto sType2 = type2.dyn_cast(); - return success(sType2 && sType1.getShape() == sType2.getShape()); - } + if (sType1.getRank() == -1 || sType2.getRank() == -1) + return success(); - return success(); + return success(sType1.getShape() == sType2.getShape()); } LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) { -- 2.7.4