Add FuncOp::eraseArgument
authorSean Silva <silvasean@google.com>
Wed, 13 Nov 2019 18:59:24 +0000 (10:59 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 13 Nov 2019 18:59:55 +0000 (10:59 -0800)
This is a quite complex operation that users are likely to attempt to write
themselves and get wrong (citation: users=me).

Ideally, we could pull this into FunctionLike, but for now, the
FunctionType rewriting makes it FuncOp specific. We would need some hook
for rewriting the function type (which for LLVM's func op, would need to
rewrite the underlying LLVM type).

PiperOrigin-RevId: 280234164

mlir/include/mlir/IR/Function.h
mlir/lib/IR/Function.cpp
mlir/test/IR/test-func-erase-arg.mlir [new file with mode: 0644]
mlir/test/IR/test-func-set-type.mlir [new file with mode: 0644]
mlir/test/lib/IR/CMakeLists.txt
mlir/test/lib/IR/TestFunc.cpp [new file with mode: 0644]

index 228b030..83489f6 100644 (file)
@@ -66,6 +66,12 @@ public:
   void print(OpAsmPrinter &p);
   LogicalResult verify();
 
+  /// Erase a single argument at `argIndex`.
+  void eraseArgument(unsigned argIndex) { eraseArguments({argIndex}); }
+  /// Erases the arguments listed in `argIndices`.
+  /// `argIndices` is allowed to have duplicates and can be in any order.
+  void eraseArguments(ArrayRef<unsigned> argIndices);
+
   /// Returns the type of this function.
   FunctionType getType() {
     return getAttrOfType<TypeAttr>(getTypeAttrName())
@@ -77,10 +83,20 @@ public:
   /// operation and it is up to the caller to ensure that this is legal for this
   /// function, and to restore invariants:
   ///  - the entry block args must be updated to match the function params.
-  ///  - the arguments attributes may need an update: if the new type has less
-  ///    parameters we drop the extra attributes, if there are more parameters
-  ///    they won't have any attributes.
+  ///  - the argument/result attributes may need an update: if the new type has
+  ///  less parameters we drop the extra attributes, if there are more
+  ///  parameters they won't have any attributes.
   void setType(FunctionType newType) {
+    SmallVector<char, 16> nameBuf;
+    auto oldType = getType();
+    for (int i = newType.getNumInputs(), e = oldType.getNumInputs(); i < e;
+         i++) {
+      removeAttr(getArgAttrName(i, nameBuf));
+    }
+    for (int i = newType.getNumResults(), e = oldType.getNumResults(); i < e;
+         i++) {
+      removeAttr(getResultAttrName(i, nameBuf));
+    }
     setAttr(getTypeAttrName(), TypeAttr::get(newType));
   }
 
index 4f5a473..4e10350 100644 (file)
@@ -23,6 +23,7 @@
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/OpImplementation.h"
+#include "llvm/ADT/BitVector.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/Twine.h"
@@ -112,6 +113,40 @@ LogicalResult FuncOp::verify() {
   return success();
 }
 
+void FuncOp::eraseArguments(ArrayRef<unsigned> argIndices) {
+  auto oldType = getType();
+  int originalNumArgs = oldType.getNumInputs();
+  llvm::BitVector eraseIndices(originalNumArgs);
+  for (auto index : argIndices)
+    eraseIndices.set(index);
+  auto shouldEraseArg = [&](int i) { return eraseIndices.test(i); };
+
+  // There are 3 things that need to be updated:
+  // - Function type.
+  // - Arg attrs.
+  // - Block arguments of entry block.
+
+  // Update the function type and arg attrs.
+  SmallVector<Type, 4> newInputTypes;
+  SmallVector<NamedAttributeList, 4> newArgAttrs;
+  for (int i = 0; i < originalNumArgs; i++) {
+    if (shouldEraseArg(i))
+      continue;
+    newInputTypes.emplace_back(oldType.getInput(i));
+    newArgAttrs.emplace_back(getArgAttrDict(i));
+  }
+  setType(FunctionType::get(newInputTypes, oldType.getResults(), getContext()));
+  setAllArgAttrs(newArgAttrs);
+
+  // Update the entry block's arguments.
+  // We do this in reverse so that we erase later indices before earlier
+  // indices, to avoid shifting the later indices.
+  Block &entry = front();
+  for (int i = 0; i < originalNumArgs; i++)
+    if (shouldEraseArg(originalNumArgs - i - 1))
+      entry.eraseArgument(originalNumArgs - i - 1);
+}
+
 /// Add an entry block to an empty function, and set up the block arguments
 /// to match the signature of the function.
 Block *FuncOp::addEntryBlock() {
diff --git a/mlir/test/IR/test-func-erase-arg.mlir b/mlir/test/IR/test-func-erase-arg.mlir
new file mode 100644 (file)
index 0000000..2d6c71e
--- /dev/null
@@ -0,0 +1,76 @@
+// RUN: mlir-opt %s -test-func-erase-arg -split-input-file | FileCheck %s
+
+// CHECK: func @f()
+// CHECK-NOT: attributes{{.*}}arg
+func @f(%arg0: f32 {test.erase_this_arg}) {
+  return
+}
+
+// -----
+
+// CHECK: func @f(%arg0: f32 {test.A})
+// CHECK-NOT: attributes{{.*}}arg
+func @f(
+  %arg0: f32 {test.erase_this_arg},
+  %arg1: f32 {test.A}) {
+  return
+}
+
+// -----
+
+// CHECK: func @f(%arg0: f32 {test.A})
+// CHECK-NOT: attributes{{.*}}arg
+func @f(
+  %arg0: f32 {test.A},
+  %arg1: f32 {test.erase_this_arg}) {
+  return
+}
+
+// -----
+
+// CHECK: func @f(%arg0: f32 {test.A}, %arg1: f32 {test.B})
+// CHECK-NOT: attributes{{.*}}arg
+func @f(
+  %arg0: f32 {test.A},
+  %arg1: f32 {test.erase_this_arg},
+  %arg2: f32 {test.B}) {
+  return
+}
+
+// -----
+
+// CHECK: func @f(%arg0: f32 {test.A}, %arg1: f32 {test.B})
+// CHECK-NOT: attributes{{.*}}arg
+func @f(
+  %arg0: f32 {test.A},
+  %arg1: f32 {test.erase_this_arg},
+  %arg2: f32 {test.erase_this_arg},
+  %arg3: f32 {test.B}) {
+  return
+}
+
+// -----
+
+// CHECK: func @f(%arg0: f32 {test.A}, %arg1: f32 {test.B}, %arg2: f32 {test.C})
+// CHECK-NOT: attributes{{.*}}arg
+func @f(
+  %arg0: f32 {test.A},
+  %arg1: f32 {test.erase_this_arg},
+  %arg2: f32 {test.B},
+  %arg3: f32 {test.erase_this_arg},
+  %arg4: f32 {test.C}) {
+  return
+}
+
+// -----
+
+// CHECK: func @f(%arg0: tensor<1xf32>, %arg1: tensor<2xf32>, %arg2: tensor<3xf32>)
+// CHECK-NOT: attributes{{.*}}arg
+func @f(
+  %arg0: tensor<1xf32>,
+  %arg1: f32 {test.erase_this_arg},
+  %arg2: tensor<2xf32>,
+  %arg3: f32 {test.erase_this_arg},
+  %arg4: tensor<3xf32>) {
+  return
+}
diff --git a/mlir/test/IR/test-func-set-type.mlir b/mlir/test/IR/test-func-set-type.mlir
new file mode 100644 (file)
index 0000000..0ec890e
--- /dev/null
@@ -0,0 +1,25 @@
+// RUN: mlir-opt %s -test-func-set-type -split-input-file | FileCheck %s --dump-input=fail
+
+// It's currently not possible to have an attribute with a function type due to
+// parser ambiguity. So instead we reference a function declaration to take the
+// type from.
+
+// -----
+
+// Test case: The setType call needs to erase some arg attrs.
+
+// CHECK: func @erase_arg(f32 {test.A})
+// CHECK-NOT: attributes{{.*arg[0-9]}}
+func @t(f32)
+func @erase_arg(%arg0: f32 {test.A}, %arg1: f32 {test.B})
+attributes {test.set_type_from = @t}
+
+// -----
+
+// Test case: The setType call needs to erase some result attrs.
+
+// CHECK: func @erase_result() -> (f32 {test.A})
+// CHECK-NOT: attributes{{.*result[0-9]}}
+func @t() -> (f32)
+func @erase_result() -> (f32 {test.A}, f32 {test.B})
+attributes {test.set_type_from = @t}
index c8ff810..9e3b8fb 100644 (file)
@@ -1,4 +1,5 @@
 add_llvm_library(MLIRTestIR
+  TestFunc.cpp
   TestSymbolUses.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/test/lib/IR/TestFunc.cpp b/mlir/test/lib/IR/TestFunc.cpp
new file mode 100644 (file)
index 0000000..880d078
--- /dev/null
@@ -0,0 +1,67 @@
+//===- TestFunctionLike.cpp - Pass to test helpers on FunctionLike --------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/Function.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// This is a test pass for verifying FuncOp's eraseArgument method.
+struct TestFuncEraseArg : public ModulePass<TestFuncEraseArg> {
+  void runOnModule() override {
+    auto module = getModule();
+
+    for (FuncOp func : module.getOps<FuncOp>()) {
+      SmallVector<unsigned, 4> indicesToErase;
+      for (auto argIndex : llvm::seq<int>(0, func.getNumArguments())) {
+        if (func.getArgAttr(argIndex, "test.erase_this_arg")) {
+          // Push back twice to test that duplicate arg indices are handled
+          // correctly.
+          indicesToErase.push_back(argIndex);
+          indicesToErase.push_back(argIndex);
+        }
+      }
+      // Reverse the order to test that unsorted index lists are handled
+      // correctly.
+      std::reverse(indicesToErase.begin(), indicesToErase.end());
+      func.eraseArguments(indicesToErase);
+    }
+  }
+};
+
+/// This is a test pass for verifying FuncOp's setType method.
+struct TestFuncSetType : public ModulePass<TestFuncSetType> {
+  void runOnModule() override {
+    auto module = getModule();
+    SymbolTable symbolTable(module);
+
+    for (FuncOp func : module.getOps<FuncOp>()) {
+      auto sym = func.getAttrOfType<FlatSymbolRefAttr>("test.set_type_from");
+      if (!sym)
+        continue;
+      func.setType(symbolTable.lookup<FuncOp>(sym.getValue()).getType());
+    }
+  }
+};
+} // end anonymous namespace
+
+static PassRegistration<TestFuncEraseArg> pass("test-func-erase-arg",
+                                               "Test erasing func args.");
+
+static PassRegistration<TestFuncSetType> pass2("test-func-set-type",
+                                               "Test FuncOp::setType.");