From 91563c419e57e3e22de96122247ea27540b14659 Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Thu, 22 Sep 2022 18:05:35 +0000 Subject: [PATCH] [mlir] Modify LinalgStructuredInterface to allow the computation block to have arguments only for a subset of operands. Summary: Currently there is an expectations that there is a corresponsing block argument for each operand. For some operation, it leads to unused arguments. For example, in `map`, only input operands are used for the computation. Differential Revision: https://reviews.llvm.org/D134444 --- mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td | 15 +++++++++++++++ mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 4 ++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td index bf2509f..64c2bd1 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -570,6 +570,21 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> { /*methodBody=*/"", /*defaultImplementation=*/"" >, + InterfaceMethod< + /*desc=*/[{ + Return op operands that have a corresponding argument in the basic block. + By default, the block should have an argument for each operand, but there + are expection. For example, in `map` output operand isn't used in + the block. + }], + /*retTy=*/"OpOperandVector", + /*methodName=*/"getOpOperandsMatchingBBargs", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.getInputAndOutputOperands(); + }] + >, //===------------------------------------------------------------------===// // Linalg generalization hooks. //===------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp index c30f357..3172400 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -762,11 +762,11 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) { // not used). Block &block = linalgOp->getRegion(0).front(); - if (linalgOp.getNumInputsAndOutputs() != block.getNumArguments()) + if (linalgOp.getOpOperandsMatchingBBargs().size() != block.getNumArguments()) return op->emitOpError("expected as many non-induction variable region " "arguments as the number of input/output operands"); - for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { + for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) { Type elementType = getElementTypeOrSelf(opOperand->get()); Type argType = block.getArgument(opOperand->getOperandNumber()).getType(); if (elementType != argType) -- 2.7.4