//===----------------------------------------------------------------------===//
// Operator: equal
//===----------------------------------------------------------------------===//
-def Tosa_EqualOp : Tosa_Op<"equal", [
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>,
- ResultsBroadcastableShape, Commutative, NoSideEffect]> {
+def Tosa_EqualOp : Tosa_Op<"equal", [InferTensorType, ResultsBroadcastableShape,
+ Commutative, NoSideEffect]> {
let summary = "Returns the truth value of (x == y) element-wise.";
let description = [{
let results = (outs
I1Tensor:$output
);
+
+ let extraClassDeclaration = [{
+ /// Returns when two result types are compatible for this op; method used by
+ /// InferTypeOpInterface.
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+ }];
}
//===----------------------------------------------------------------------===//
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
}
}
+static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
+ SmallVector<int64_t> &outShape) {
+ int64_t outRank = 0;
+ for (int i = 0, e = operands.size(); i != e; ++i) {
+ auto shape = operands.getShape(i);
+ if (!shape.hasRank()) {
+ // TODO(jennik): Update function to have better case handling for invalid
+ // operands and for ranked tensors.
+ return failure();
+ }
+ outRank = std::max<int64_t>(outRank, shape.getRank());
+ }
+
+ outShape.resize(outRank, 1);
+
+ for (int i = 0, e = operands.size(); i != e; ++i) {
+ auto shape = operands.getShape(i);
+ auto rankDiff = outShape.size() - shape.getRank();
+
+ for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
+ auto dim1 = outShape[i + rankDiff];
+ auto dim2 = shape.getDimSize(i);
+ auto resolvedDim = dim1;
+
+ if (dim1 == 1) {
+ resolvedDim = dim2;
+ } else if (dim2 == 1) {
+ resolvedDim = dim1;
+ } else if (dim1 != dim2) {
+ return failure();
+ }
+ outShape[i + rankDiff] = resolvedDim;
+ }
+ }
+
+ return success();
+}
+
LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
MLIRContext *context, ::llvm::Optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
return success();
}
+LogicalResult tosa::EqualOp::inferReturnTypeComponents(
+ MLIRContext *context, ::llvm::Optional<Location> location,
+ ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+ llvm::SmallVector<int64_t> outShape;
+ if (resolveBroadcastShape(operands, outShape).failed()) {
+ inferredReturnShapes.push_back(ShapedTypeComponents());
+ return success();
+ }
+
+ inferredReturnShapes.push_back(
+ ShapedTypeComponents(outShape, IntegerType::get(context, /*width=*/1)));
+ return success();
+}
+
+bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+ if (l.size() != r.size() || l.size() != 1)
+ return false;
+ return succeeded(verifyCompatibleShape(l[0], r[0]));
+}
+
LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
MLIRContext *context, ::llvm::Optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
#undef REDUCE_SHAPE_INFER
-static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
- SmallVector<int64_t> &outShape) {
- int64_t outRank = 0;
- for (int i = 0, e = operands.size(); i != e; ++i) {
- auto shape = operands.getShape(i);
- if (!shape.hasRank()) {
- return failure();
- }
- outRank = std::max<int64_t>(outRank, shape.getRank());
- }
-
- outShape.resize(outRank, 1);
-
- for (int i = 0, e = operands.size(); i != e; ++i) {
- auto shape = operands.getShape(i);
- auto rankDiff = outShape.size() - shape.getRank();
-
- for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
- auto dim1 = outShape[i + rankDiff];
- auto dim2 = shape.getDimSize(i);
- auto resolvedDim = dim1;
-
- if (dim1 == 1) {
- resolvedDim = dim2;
- } else if (dim2 == 1) {
- resolvedDim = dim1;
- } else if (dim1 != dim2) {
- return failure();
- }
- outShape[i + rankDiff] = resolvedDim;
- }
- }
-
- return success();
-}
-
static LogicalResult NAryInferReturnTypes(
const ValueShapeRange &operands,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
NARY_SHAPE_INFER(tosa::ClampOp)
NARY_SHAPE_INFER(tosa::ClzOp)
NARY_SHAPE_INFER(tosa::DivOp)
-NARY_SHAPE_INFER(tosa::EqualOp)
NARY_SHAPE_INFER(tosa::ExpOp)
NARY_SHAPE_INFER(tosa::FloorOp)
NARY_SHAPE_INFER(tosa::GreaterEqualOp)