void print(OpAsmPrinter &p);
LogicalResult verify();
+ /// Erase a single argument at `argIndex`.
+ void eraseArgument(unsigned argIndex) { eraseArguments({argIndex}); }
+ /// Erases the arguments listed in `argIndices`.
+ /// `argIndices` is allowed to have duplicates and can be in any order.
+ void eraseArguments(ArrayRef<unsigned> argIndices);
+
/// Returns the type of this function.
FunctionType getType() {
return getAttrOfType<TypeAttr>(getTypeAttrName())
/// operation and it is up to the caller to ensure that this is legal for this
/// function, and to restore invariants:
/// - the entry block args must be updated to match the function params.
- /// - the arguments attributes may need an update: if the new type has less
- /// parameters we drop the extra attributes, if there are more parameters
- /// they won't have any attributes.
+ /// - the argument/result attributes may need an update: if the new type has
+ /// less parameters we drop the extra attributes, if there are more
+ /// parameters they won't have any attributes.
void setType(FunctionType newType) {
+ SmallVector<char, 16> nameBuf;
+ auto oldType = getType();
+ for (int i = newType.getNumInputs(), e = oldType.getNumInputs(); i < e;
+ i++) {
+ removeAttr(getArgAttrName(i, nameBuf));
+ }
+ for (int i = newType.getNumResults(), e = oldType.getNumResults(); i < e;
+ i++) {
+ removeAttr(getResultAttrName(i, nameBuf));
+ }
setAttr(getTypeAttrName(), TypeAttr::get(newType));
}
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
+#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/Twine.h"
return success();
}
+void FuncOp::eraseArguments(ArrayRef<unsigned> argIndices) {
+ auto oldType = getType();
+ int originalNumArgs = oldType.getNumInputs();
+ llvm::BitVector eraseIndices(originalNumArgs);
+ for (auto index : argIndices)
+ eraseIndices.set(index);
+ auto shouldEraseArg = [&](int i) { return eraseIndices.test(i); };
+
+ // There are 3 things that need to be updated:
+ // - Function type.
+ // - Arg attrs.
+ // - Block arguments of entry block.
+
+ // Update the function type and arg attrs.
+ SmallVector<Type, 4> newInputTypes;
+ SmallVector<NamedAttributeList, 4> newArgAttrs;
+ for (int i = 0; i < originalNumArgs; i++) {
+ if (shouldEraseArg(i))
+ continue;
+ newInputTypes.emplace_back(oldType.getInput(i));
+ newArgAttrs.emplace_back(getArgAttrDict(i));
+ }
+ setType(FunctionType::get(newInputTypes, oldType.getResults(), getContext()));
+ setAllArgAttrs(newArgAttrs);
+
+ // Update the entry block's arguments.
+ // We do this in reverse so that we erase later indices before earlier
+ // indices, to avoid shifting the later indices.
+ Block &entry = front();
+ for (int i = 0; i < originalNumArgs; i++)
+ if (shouldEraseArg(originalNumArgs - i - 1))
+ entry.eraseArgument(originalNumArgs - i - 1);
+}
+
/// Add an entry block to an empty function, and set up the block arguments
/// to match the signature of the function.
Block *FuncOp::addEntryBlock() {
--- /dev/null
+// RUN: mlir-opt %s -test-func-erase-arg -split-input-file | FileCheck %s
+
+// CHECK: func @f()
+// CHECK-NOT: attributes{{.*}}arg
+func @f(%arg0: f32 {test.erase_this_arg}) {
+ return
+}
+
+// -----
+
+// CHECK: func @f(%arg0: f32 {test.A})
+// CHECK-NOT: attributes{{.*}}arg
+func @f(
+ %arg0: f32 {test.erase_this_arg},
+ %arg1: f32 {test.A}) {
+ return
+}
+
+// -----
+
+// CHECK: func @f(%arg0: f32 {test.A})
+// CHECK-NOT: attributes{{.*}}arg
+func @f(
+ %arg0: f32 {test.A},
+ %arg1: f32 {test.erase_this_arg}) {
+ return
+}
+
+// -----
+
+// CHECK: func @f(%arg0: f32 {test.A}, %arg1: f32 {test.B})
+// CHECK-NOT: attributes{{.*}}arg
+func @f(
+ %arg0: f32 {test.A},
+ %arg1: f32 {test.erase_this_arg},
+ %arg2: f32 {test.B}) {
+ return
+}
+
+// -----
+
+// CHECK: func @f(%arg0: f32 {test.A}, %arg1: f32 {test.B})
+// CHECK-NOT: attributes{{.*}}arg
+func @f(
+ %arg0: f32 {test.A},
+ %arg1: f32 {test.erase_this_arg},
+ %arg2: f32 {test.erase_this_arg},
+ %arg3: f32 {test.B}) {
+ return
+}
+
+// -----
+
+// CHECK: func @f(%arg0: f32 {test.A}, %arg1: f32 {test.B}, %arg2: f32 {test.C})
+// CHECK-NOT: attributes{{.*}}arg
+func @f(
+ %arg0: f32 {test.A},
+ %arg1: f32 {test.erase_this_arg},
+ %arg2: f32 {test.B},
+ %arg3: f32 {test.erase_this_arg},
+ %arg4: f32 {test.C}) {
+ return
+}
+
+// -----
+
+// CHECK: func @f(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>, %arg2: tensor<3xf32>)
+// CHECK-NOT: attributes{{.*}}arg
+func @f(
+ %arg0: tensor<1xf32>,
+ %arg1: f32 {test.erase_this_arg},
+ %arg2: tensor<2xf32>,
+ %arg3: f32 {test.erase_this_arg},
+ %arg4: tensor<3xf32>) {
+ return
+}
--- /dev/null
+// RUN: mlir-opt %s -test-func-set-type -split-input-file | FileCheck %s --dump-input=fail
+
+// It's currently not possible to have an attribute with a function type due to
+// parser ambiguity. So instead we reference a function declaration to take the
+// type from.
+
+// -----
+
+// Test case: The setType call needs to erase some arg attrs.
+
+// CHECK: func @erase_arg(f32 {test.A})
+// CHECK-NOT: attributes{{.*arg[0-9]}}
+func @t(f32)
+func @erase_arg(%arg0: f32 {test.A}, %arg1: f32 {test.B})
+attributes {test.set_type_from = @t}
+
+// -----
+
+// Test case: The setType call needs to erase some result attrs.
+
+// CHECK: func @erase_result() -> (f32 {test.A})
+// CHECK-NOT: attributes{{.*result[0-9]}}
+func @t() -> (f32)
+func @erase_result() -> (f32 {test.A}, f32 {test.B})
+attributes {test.set_type_from = @t}
add_llvm_library(MLIRTestIR
+ TestFunc.cpp
TestSymbolUses.cpp
ADDITIONAL_HEADER_DIRS
--- /dev/null
+//===- TestFunctionLike.cpp - Pass to test helpers on FunctionLike --------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/Function.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// This is a test pass for verifying FuncOp's eraseArgument method.
+struct TestFuncEraseArg : public ModulePass<TestFuncEraseArg> {
+ void runOnModule() override {
+ auto module = getModule();
+
+ for (FuncOp func : module.getOps<FuncOp>()) {
+ SmallVector<unsigned, 4> indicesToErase;
+ for (auto argIndex : llvm::seq<int>(0, func.getNumArguments())) {
+ if (func.getArgAttr(argIndex, "test.erase_this_arg")) {
+ // Push back twice to test that duplicate arg indices are handled
+ // correctly.
+ indicesToErase.push_back(argIndex);
+ indicesToErase.push_back(argIndex);
+ }
+ }
+ // Reverse the order to test that unsorted index lists are handled
+ // correctly.
+ std::reverse(indicesToErase.begin(), indicesToErase.end());
+ func.eraseArguments(indicesToErase);
+ }
+ }
+};
+
+/// This is a test pass for verifying FuncOp's setType method.
+struct TestFuncSetType : public ModulePass<TestFuncSetType> {
+ void runOnModule() override {
+ auto module = getModule();
+ SymbolTable symbolTable(module);
+
+ for (FuncOp func : module.getOps<FuncOp>()) {
+ auto sym = func.getAttrOfType<FlatSymbolRefAttr>("test.set_type_from");
+ if (!sym)
+ continue;
+ func.setType(symbolTable.lookup<FuncOp>(sym.getValue()).getType());
+ }
+ }
+};
+} // end anonymous namespace
+
+static PassRegistration<TestFuncEraseArg> pass("test-func-erase-arg",
+ "Test erasing func args.");
+
+static PassRegistration<TestFuncSetType> pass2("test-func-set-type",
+ "Test FuncOp::setType.");