#define MLIR_ANALYSIS_INFERTYPEOPINTERFACE_H_
#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h"
"SmallVectorImpl<ShapedTypeComponents>&":
$inferedReturnShapes)
>,
+ InterfaceMethod<
+ /*desc=*/[{Reify the shape computation for the operation.
+
+ Insert operations using the given OpBuilder that compute the result shape.
+ }],
+ /*retTy=*/"LogicalResult",
+ /*methodName=*/"reifyReturnTypeShapes",
+ /*args=*/(ins "OpBuilder&":$builder,
+ "SmallVectorImpl<Value>&":$reifiedReturnShapes),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{ return failure(); }]
+ >,
];
}
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/PatternMatch.h"
LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
MLIRContext *context, Optional<Location> location, ValueRange operands,
ArrayRef<NamedAttribute> attributes, RegionRange regions,
- SmallVectorImpl<ShapedTypeComponents> &inferedComponents) {
- // Create return type consisting of the first element of each shape of the
- // input operands or unknown for unranked operand.
- std::vector<int64_t> shape;
- shape.reserve(operands.size());
- for (auto operandType : operands.getTypes()) {
- if (auto sval = operandType.dyn_cast<ShapedType>()) {
- if (sval.hasRank())
- shape.push_back(sval.getShape().front());
- else
- shape.push_back(ShapedType::kDynamicSize);
- } else {
- return emitOptionalError(location, "only shaped type operands allowed");
- }
+ SmallVectorImpl<ShapedTypeComponents> &inferedReturnShapes) {
+ // Create return type consisting of the first element of the first operand's
+ // shape.
+ auto operandType = *operands.getTypes().begin();
+ auto sval = operandType.dyn_cast<ShapedType>();
+ if (!sval) {
+ return emitOptionalError(location, "only shaped type operands allowed");
}
- inferedComponents.reserve(1);
+ int64_t dim =
+ sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
auto type = IntegerType::get(17, context);
- inferedComponents.emplace_back(shape, type);
+ inferedReturnShapes.push_back(ShapedTypeComponents({dim}, type));
+ return success();
+}
+
+LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
+ OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) {
+ shapes = SmallVector<Value, 1>{
+ builder.createOrFold<mlir::DimOp>(getLoc(), getOperand(0), 0)};
return success();
}
}
}
+static void reifyReturnShape(Operation *op) {
+ OpBuilder b(op);
+
+ // Reify the return type shapes of the op and report each as a remark.
+ auto shapedOp = cast<OpWithShapedTypeInferTypeInterfaceOp>(op);
+ SmallVector<Value, 2> shapes;
+ if (failed(shapedOp.reifyReturnTypeShapes(b, shapes)))
+ return;
+ for (auto it : llvm::enumerate(shapes))
+ op->emitRemark() << "value " << it.index() << ": "
+ << it.value().getDefiningOp();
+}
+
struct TestReturnTypeDriver : public FunctionPass<TestReturnTypeDriver> {
void runOnFunction() override {
if (getFunction().getName() == "testCreateFunctions") {
};
return;
}
+ if (getFunction().getName() == "testReifyFunctions") {
+ std::vector<Operation *> ops;
+ // Collect ops to avoid triggering on inserted ops.
+ for (auto &op : getFunction().getBody().front())
+ if (isa<OpWithShapedTypeInferTypeInterfaceOp>(op))
+ ops.push_back(&op);
+ // Reify the return type shapes for each collected op.
+ for (auto *op : ops)
+ reifyReturnShape(op);
+ }
}
};
} // end anonymous namespace
// CHECK: "test.no_attributes"
%good = "test.no_attributes"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
// CHECK: "test.op_with_shaped_type_infer_type_if"
-// CHECK-SAME: (tensor<10xf32>, tensor<10xf32>) -> tensor<10x10xi17>
+// CHECK-SAME: (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi17>
// CHECK: "test.op_with_shaped_type_infer_type_if"
-// CHECK-SAME: (tensor<10xf32>, tensor<20xi32>) -> tensor<10x20xi17>
+// CHECK-SAME: (tensor<10xf32>, tensor<20xi32>) -> tensor<10xi17>
// CHECK: "test.op_with_shaped_type_infer_type_if"
-// CHECK-SAME: (tensor<20xi32>, tensor<10xf32>) -> tensor<20x10xi17>
+// CHECK-SAME: (tensor<20xi32>, tensor<10xf32>) -> tensor<20xi17>
// CHECK: "test.op_with_shaped_type_infer_type_if"
-// CHECK-SAME: (tensor<20xi32>, tensor<20xi32>) -> tensor<20x20xi17>
+// CHECK-SAME: (tensor<20xi32>, tensor<20xi32>) -> tensor<20xi17>
// CHECK: "test.op_with_infer_type_if"
// CHECK-SAME: (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
// CHECK: "test.op_with_infer_type_if"
%bad = "test.op_with_infer_type_if"(%arg0, %arg1) : (tensor<10xf32>, tensor<20xf32>) -> tensor<*xf32>
return
}
+
+// -----
+
+// CHECK-LABEL: testReifyFunctions
+func @testReifyFunctions(%arg0 : tensor<10xf32>, %arg1 : tensor<20xf32>) {
+ // expected-remark@+1 {{constant 10}}
+ %0 = "test.op_with_shaped_type_infer_type_if"(%arg0, %arg1) : (tensor<10xf32>, tensor<20xf32>) -> tensor<10xi17>
+ // expected-remark@+1 {{constant 20}}
+ %1 = "test.op_with_shaped_type_infer_type_if"(%arg1, %arg0) : (tensor<20xf32>, tensor<10xf32>) -> tensor<20xi17>
+ return
+}