[MLIR][SPIRV] Start module combiner.
authorergawy <kareem.ergawy@gmail.com>
Fri, 30 Oct 2020 18:36:19 +0000 (14:36 -0400)
committerLei Zhang <antiagainst@google.com>
Fri, 30 Oct 2020 20:55:43 +0000 (16:55 -0400)
This commit adds a new library that merges/combines a number of spv
modules into a combined one. The library has a single entry point:
combine(...).

To combine a number of MLIR spv modules, we move all the module-level ops
from all the input modules into one big combined module. To that end, the
combination process can proceed in 2 phases:

  (1) resolving conflicts between pairs of ops from different modules
  (2) deduplicate equivalent ops/sub-ops in the merged module. (TODO)

This patch implements only the first phase.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D90477

mlir/include/mlir/Dialect/SPIRV/ModuleCombiner.h [new file with mode: 0644]
mlir/lib/Dialect/SPIRV/CMakeLists.txt
mlir/lib/Dialect/SPIRV/Linking/CMakeLists.txt [new file with mode: 0644]
mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/CMakeLists.txt [new file with mode: 0644]
mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp [new file with mode: 0644]
mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/basic.mlir [new file with mode: 0644]
mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict_resolution.mlir [new file with mode: 0644]
mlir/test/lib/Dialect/SPIRV/CMakeLists.txt
mlir/test/lib/Dialect/SPIRV/TestModuleCombiner.cpp [new file with mode: 0644]
mlir/tools/mlir-opt/mlir-opt.cpp

diff --git a/mlir/include/mlir/Dialect/SPIRV/ModuleCombiner.h b/mlir/include/mlir/Dialect/SPIRV/ModuleCombiner.h
new file mode 100644 (file)
index 0000000..b7ecd57
--- /dev/null
@@ -0,0 +1,69 @@
+//===- ModuleCombiner.h - MLIR SPIR-V Module Combiner -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the entry point to the SPIR-V module combiner library.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_MODULECOMBINER_H_
+#define MLIR_DIALECT_SPIRV_MODULECOMBINER_H_
+
+#include "mlir/Dialect/SPIRV/SPIRVModule.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+class OpBuilder;
+
+namespace spirv {
+class ModuleOp;
+
+/// To combine a number of MLIR SPIR-V modules, we move all the module-level ops
+/// from all the input modules into one big combined module. To that end, the
+/// combination process proceeds in 2 phases:
+///
+///   (1) resolve conflicts between pairs of ops from different modules
+///   (2) deduplicate equivalent ops/sub-ops in the merged module. (TODO)
+///
+/// For the conflict resolution phase, the following rules are employed to
+/// resolve such conflicts:
+///
+///   - If 2 spv.func's have the same symbol name, then rename one of the
+///   functions.
+///   - If an spv.func and another op have the same symbol name, then rename the
+///   other symbol.
+///   - If none of the 2 conflicting ops are spv.func, then rename either.
+///
+/// In all cases, the references to the updated symbol are also updated to
+/// reflect the change.
+///
+/// \param modules the list of modules to combine. Input modules are not
+/// modified.
+/// \param combinedMdouleBuilder an OpBuilder to be used for
+/// building up the combined module.
+/// \param symbRenameListener a listener that gets called everytime a symbol in
+///                           one of the input modules is renamed. The arguments
+///                           passed to the listener are: the input
+///                           spirv::ModuleOp that contains the renamed symbol,
+///                           a StringRef to the old symbol name, and a
+///                           StringRef to the new symbol name. Note that it is
+///                           the responsibility of the caller to properly
+///                           retain the storage underlying the passed
+///                           StringRefs if the listener callback outlives this
+///                           function call.
+///
+/// \return the combined module.
+OwningSPIRVModuleRef
+combine(llvm::MutableArrayRef<ModuleOp> modules,
+        OpBuilder &combinedModuleBuilder,
+        llvm::function_ref<void(ModuleOp, StringRef, StringRef)>
+            symbRenameListener);
+} // namespace spirv
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SPIRV_MODULECOMBINER_H_
index 10f06fd..f371821 100644 (file)
@@ -34,5 +34,6 @@ add_mlir_dialect_library(MLIRSPIRV
   MLIRTransforms
   )
 
+add_subdirectory(Linking)
 add_subdirectory(Serialization)
 add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/SPIRV/Linking/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Linking/CMakeLists.txt
new file mode 100644 (file)
index 0000000..4cc0168
--- /dev/null
@@ -0,0 +1 @@
+add_subdirectory(ModuleCombiner)
diff --git a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/CMakeLists.txt
new file mode 100644 (file)
index 0000000..69af5a6
--- /dev/null
@@ -0,0 +1,11 @@
+add_mlir_dialect_library(MLIRSPIRVModuleCombiner
+  ModuleCombiner.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRSPIRV
+  MLIRSupport
+  )
diff --git a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
new file mode 100644 (file)
index 0000000..7687ab2
--- /dev/null
@@ -0,0 +1,181 @@
+//===- ModuleCombiner.cpp - MLIR SPIR-V Module Combiner ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the the SPIR-V module combiner library.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/ModuleCombiner.h"
+
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/SymbolTable.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/StringExtras.h"
+
+using namespace mlir;
+
+static constexpr unsigned maxFreeID = 1 << 20;
+
+static SmallString<64> renameSymbol(StringRef oldSymName, unsigned &lastUsedID,
+                                    spirv::ModuleOp combinedModule) {
+  SmallString<64> newSymName(oldSymName);
+  newSymName.push_back('_');
+
+  while (lastUsedID < maxFreeID) {
+    std::string possible = (newSymName + llvm::utostr(++lastUsedID)).str();
+
+    if (!SymbolTable::lookupSymbolIn(combinedModule, possible)) {
+      newSymName += llvm::utostr(lastUsedID);
+      break;
+    }
+  }
+
+  return newSymName;
+}
+
+/// Check if a symbol with the same name as op already exists in source. If so,
+/// rename op and update all its references in target.
+static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op,
+                                            spirv::ModuleOp target,
+                                            spirv::ModuleOp source,
+                                            unsigned &lastUsedID) {
+  if (!SymbolTable::lookupSymbolIn(source, op.getName()))
+    return success();
+
+  StringRef oldSymName = op.getName();
+  SmallString<64> newSymName = renameSymbol(oldSymName, lastUsedID, target);
+
+  if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, target)))
+    return op.emitError("unable to update all symbol uses for ")
+           << oldSymName << " to " << newSymName;
+
+  SymbolTable::setSymbolName(op, newSymName);
+  return success();
+}
+
+namespace mlir {
+namespace spirv {
+
+// TODO Properly test symbol rename listener mechanism.
+
+OwningSPIRVModuleRef
+combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
+        OpBuilder &combinedModuleBuilder,
+        llvm::function_ref<void(ModuleOp, StringRef, StringRef)>
+            symRenameListener) {
+  unsigned lastUsedID = 0;
+
+  if (modules.empty())
+    return nullptr;
+
+  auto addressingModel = modules[0].addressing_model();
+  auto memoryModel = modules[0].memory_model();
+
+  auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>(
+      modules[0].getLoc(), addressingModel, memoryModel);
+  combinedModuleBuilder.setInsertionPointToStart(&*combinedModule.getBody());
+
+  // In some cases, a symbol in the (current state of the) combined module is
+  // renamed in order to maintain the conflicting symbol in the input module
+  // being merged. For example, if the conflict is between a global variable in
+  // the current combined module and a function in the input module, the global
+  // varaible is renamed. In order to notify listeners of the symbol updates in
+  // such cases, we need to keep track of the module from which the renamed
+  // symbol in the combined module originated. This map keeps such information.
+  DenseMap<StringRef, spirv::ModuleOp> symNameToModuleMap;
+
+  for (auto module : modules) {
+    if (module.addressing_model() != addressingModel ||
+        module.memory_model() != memoryModel) {
+      module.emitError(
+          "input modules differ in addressing model and/or memory model");
+      return nullptr;
+    }
+
+    spirv::ModuleOp moduleClone = module.clone();
+
+    // In the combined module, rename all symbols that conflict with symbols
+    // from the current input module. This renmaing applies to all ops except
+    // for spv.funcs. This way, if the conflicting op in the input module is
+    // non-spv.func, we rename that symbol instead and maintain the spv.func in
+    // the combined module name as it is.
+    for (auto &op : combinedModule.getBlock().without_terminator()) {
+      if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
+        StringRef oldSymName = symbolOp.getName();
+
+        if (!isa<FuncOp>(op) &&
+            failed(updateSymbolAndAllUses(symbolOp, combinedModule, moduleClone,
+                                          lastUsedID)))
+          return nullptr;
+
+        StringRef newSymName = symbolOp.getName();
+
+        if (symRenameListener && oldSymName != newSymName) {
+          spirv::ModuleOp originalModule =
+              symNameToModuleMap.lookup(oldSymName);
+
+          if (!originalModule) {
+            module.emitError("unable to find original ModuleOp for symbol ")
+                << oldSymName;
+            return nullptr;
+          }
+
+          symRenameListener(originalModule, oldSymName, newSymName);
+
+          // Since the symbol name is updated, there is no need to maintain the
+          // entry that assocaites the old symbol name with the original module.
+          symNameToModuleMap.erase(oldSymName);
+          // Instead, add a new entry to map the new symbol name to the original
+          // module in case it gets renamed again later.
+          symNameToModuleMap[newSymName] = originalModule;
+        }
+      }
+    }
+
+    // In the current input module, rename all symbols that conflict with
+    // symbols from the combined module. This includes renaming spv.funcs.
+    for (auto &op : moduleClone.getBlock().without_terminator()) {
+      if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
+        StringRef oldSymName = symbolOp.getName();
+
+        if (failed(updateSymbolAndAllUses(symbolOp, moduleClone, combinedModule,
+                                          lastUsedID)))
+          return nullptr;
+
+        StringRef newSymName = symbolOp.getName();
+
+        if (symRenameListener && oldSymName != newSymName) {
+          symRenameListener(module, oldSymName, newSymName);
+
+          // Insert the module associated with the symbol name.
+          auto emplaceResult =
+              symNameToModuleMap.try_emplace(symbolOp.getName(), module);
+
+          // If an entry with the same symbol name is already present, this must
+          // be a problem with the implementation, specially clean-up of the map
+          // while iterating over the combined module above.
+          if (!emplaceResult.second) {
+            module.emitError("did not expect to find an entry for symbol ")
+                << symbolOp.getName();
+            return nullptr;
+          }
+        }
+      }
+    }
+
+    // Clone all the module's ops to the combined module.
+    for (auto &op : moduleClone.getBlock().without_terminator())
+      combinedModuleBuilder.insert(op.clone());
+  }
+
+  return combinedModule;
+}
+
+} // namespace spirv
+} // namespace mlir
diff --git a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/basic.mlir b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/basic.mlir
new file mode 100644 (file)
index 0000000..07fd41e
--- /dev/null
@@ -0,0 +1,50 @@
+// RUN: mlir-opt -test-spirv-module-combiner -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// CHECK:      module {
+// CHECK-NEXT:   spv.module Logical GLSL450 {
+// CHECK-NEXT:     spv.specConstant @m1_sc
+// CHECK-NEXT:     spv.specConstant @m2_sc
+// CHECK-NEXT:     spv.func @variable_init_spec_constant
+// CHECK-NEXT:       spv._reference_of @m2_sc
+// CHECK-NEXT:       spv.Variable init
+// CHECK-NEXT:       spv.Return
+// CHECK-NEXT:     }
+// CHECK-NEXT:   }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+  spv.specConstant @m1_sc = 42.42 : f32
+}
+
+spv.module Logical GLSL450 {
+  spv.specConstant @m2_sc = 42 : i32
+  spv.func @variable_init_spec_constant() -> () "None" {
+    %0 = spv._reference_of @m2_sc : i32
+    %1 = spv.Variable init(%0) : !spv.ptr<i32, Function>
+    spv.Return
+  }
+}
+}
+
+// -----
+
+module {
+spv.module Physical64 GLSL450 {
+}
+
+// expected-error @+1 {{input modules differ in addressing model and/or memory model}}
+spv.module Logical GLSL450 {
+}
+}
+
+// -----
+
+module {
+spv.module Logical Simple {
+}
+
+// expected-error @+1 {{input modules differ in addressing model and/or memory model}}
+spv.module Logical GLSL450 {
+}
+}
diff --git a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict_resolution.mlir b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict_resolution.mlir
new file mode 100644 (file)
index 0000000..f5535c4
--- /dev/null
@@ -0,0 +1,682 @@
+// RUN: mlir-opt -test-spirv-module-combiner -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// Test basic renaming of conflicting funcOps.
+
+// CHECK:      module {
+// CHECK-NEXT:   spv.module Logical GLSL450 {
+// CHECK-NEXT:     spv.func @foo
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+
+// CHECK-NEXT:     spv.func @foo_1
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+// CHECK-NEXT:   }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+  spv.func @foo(%arg0 : i32) -> i32 "None" {
+    spv.ReturnValue %arg0 : i32
+  }
+}
+
+spv.module Logical GLSL450 {
+  spv.func @foo(%arg0 : f32) -> f32 "None" {
+    spv.ReturnValue %arg0 : f32
+  }
+}
+}
+
+// -----
+
+// Test basic renaming of conflicting funcOps across 3 modules.
+
+// CHECK:      module {
+// CHECK-NEXT:   spv.module Logical GLSL450 {
+// CHECK-NEXT:     spv.func @foo
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+
+// CHECK-NEXT:     spv.func @foo_1
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+
+// CHECK-NEXT:     spv.func @foo_2
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+// CHECK-NEXT:   }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+  spv.func @foo(%arg0 : i32) -> i32 "None" {
+    spv.ReturnValue %arg0 : i32
+  }
+}
+
+spv.module Logical GLSL450 {
+  spv.func @foo(%arg0 : f32) -> f32 "None" {
+    spv.ReturnValue %arg0 : f32
+  }
+}
+
+spv.module Logical GLSL450 {
+  spv.func @foo(%arg0 : i32) -> i32 "None" {
+    spv.ReturnValue %arg0 : i32
+  }
+}
+}
+
+// -----
+
+// Test properly updating references to a renamed funcOp.
+
+// CHECK:      module {
+// CHECK-NEXT:   spv.module Logical GLSL450 {
+// CHECK-NEXT:     spv.func @foo
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+
+// CHECK-NEXT:     spv.func @foo_1
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+
+// CHECK-NEXT:     spv.func @bar
+// CHECK-NEXT:       spv.FunctionCall @foo_1
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+// CHECK-NEXT:   }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+  spv.func @foo(%arg0 : i32) -> i32 "None" {
+    spv.ReturnValue %arg0 : i32
+  }
+}
+
+spv.module Logical GLSL450 {
+  spv.func @foo(%arg0 : f32) -> f32 "None" {
+    spv.ReturnValue %arg0 : f32
+  }
+
+  spv.func @bar(%arg0 : f32) -> f32 "None" {
+    %0 = spv.FunctionCall @foo(%arg0) : (f32) ->  (f32)
+    spv.ReturnValue %0 : f32
+  }
+}
+}
+
+// -----
+
+// Test properly updating references to a renamed funcOp if the functionCallOp
+// preceeds the callee funcOp definition.
+
+// CHECK:      module {
+// CHECK-NEXT:   spv.module Logical GLSL450 {
+// CHECK-NEXT:     spv.func @foo
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+
+// CHECK-NEXT:     spv.func @bar
+// CHECK-NEXT:       spv.FunctionCall @foo_1
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+
+// CHECK-NEXT:     spv.func @foo_1
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+// CHECK-NEXT:   }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+  spv.func @foo(%arg0 : i32) -> i32 "None" {
+    spv.ReturnValue %arg0 : i32
+  }
+}
+
+spv.module Logical GLSL450 {
+  spv.func @bar(%arg0 : f32) -> f32 "None" {
+    %0 = spv.FunctionCall @foo(%arg0) : (f32) ->  (f32)
+    spv.ReturnValue %0 : f32
+  }
+
+  spv.func @foo(%arg0 : f32) -> f32 "None" {
+    spv.ReturnValue %arg0 : f32
+  }
+}
+}
+
+// -----
+
+// Test properly updating entryPointOp and executionModeOp attached to renamed
+// funcOp.
+
+// CHECK:      module {
+// CHECK-NEXT:   spv.module Logical GLSL450 {
+// CHECK-NEXT:     spv.func @foo
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+
+// CHECK-NEXT:     spv.func @foo_1
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+
+// CHECK-NEXT:     spv.EntryPoint "GLCompute" @foo_1
+// CHECK-NEXT:     spv.ExecutionMode @foo_1 "ContractionOff"
+// CHECK-NEXT:   }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+  spv.func @foo(%arg0 : i32) -> i32 "None" {
+    spv.ReturnValue %arg0 : i32
+  }
+}
+
+spv.module Logical GLSL450 {
+  spv.func @foo(%arg0 : f32) -> f32 "None" {
+    spv.ReturnValue %arg0 : f32
+  }
+
+  spv.EntryPoint "GLCompute" @foo
+  spv.ExecutionMode @foo "ContractionOff"
+}
+}
+
+// -----
+
+// CHECK:      module {
+// CHECK-NEXT:   spv.module Logical GLSL450 {
+// CHECK-NEXT:     spv.func @foo
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+
+// CHECK-NEXT:     spv.EntryPoint "GLCompute" @fo
+// CHECK-NEXT:     spv.ExecutionMode @foo "ContractionOff"
+
+// CHECK-NEXT:     spv.func @foo_1
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+
+// CHECK-NEXT:     spv.EntryPoint "GLCompute" @foo_1
+// CHECK-NEXT:     spv.ExecutionMode @foo_1 "ContractionOff"
+// CHECK-NEXT:   }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+  spv.func @foo(%arg0 : i32) -> i32 "None" {
+    spv.ReturnValue %arg0 : i32
+  }
+  
+  spv.EntryPoint "GLCompute" @foo
+  spv.ExecutionMode @foo "ContractionOff"
+}
+
+spv.module Logical GLSL450 {
+  spv.func @foo(%arg0 : f32) -> f32 "None" {
+    spv.ReturnValue %arg0 : f32
+  }
+
+  spv.EntryPoint "GLCompute" @foo
+  spv.ExecutionMode @foo "ContractionOff"
+}
+}
+
+// -----
+
+// Resolve conflicting funcOp and globalVariableOp.
+
+// CHECK:      module {
+// CHECK-NEXT:   spv.module Logical GLSL450 {
+// CHECK-NEXT:     spv.func @foo
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+
+// CHECK-NEXT:     spv.globalVariable @foo_1
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+  spv.func @foo(%arg0 : i32) -> i32 "None" {
+    spv.ReturnValue %arg0 : i32
+  }
+}
+
+spv.module Logical GLSL450 {
+  spv.globalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
+}
+}
+
+// -----
+
+// Resolve conflicting funcOp and globalVariableOp and update the global variable's
+// references.
+
+// CHECK:      module {
+// CHECK-NEXT:   spv.module Logical GLSL450 {
+// CHECK-NEXT:     spv.func @foo
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+
+// CHECK-NEXT:     spv.globalVariable @foo_1
+// CHECK-NEXT:     spv.func @bar
+// CHECK-NEXT:       spv._address_of @foo_1
+// CHECK-NEXT:       spv.Load
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+  spv.func @foo(%arg0 : i32) -> i32 "None" {
+    spv.ReturnValue %arg0 : i32
+  }
+}
+
+spv.module Logical GLSL450 {
+  spv.globalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
+
+  spv.func @bar() -> f32 "None" {
+    %0 = spv._address_of @foo : !spv.ptr<f32, Input>
+    %1 = spv.Load "Input" %0 : f32
+    spv.ReturnValue %1 : f32
+  }
+}
+}
+
+// -----
+
+// Resolve conflicting globalVariableOp and funcOp and update the global variable's
+// references.
+
+// CHECK:      module {
+// CHECK-NEXT:   spv.module Logical GLSL450 {
+// CHECK-NEXT:     spv.globalVariable @foo_1
+// CHECK-NEXT:     spv.func @bar
+// CHECK-NEXT:       spv._address_of @foo_1
+// CHECK-NEXT:       spv.Load
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+
+// CHECK-NEXT:     spv.func @foo
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+  spv.globalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
+
+  spv.func @bar() -> f32 "None" {
+    %0 = spv._address_of @foo : !spv.ptr<f32, Input>
+    %1 = spv.Load "Input" %0 : f32
+    spv.ReturnValue %1 : f32
+  }
+}
+
+spv.module Logical GLSL450 {
+  spv.func @foo(%arg0 : i32) -> i32 "None" {
+    spv.ReturnValue %arg0 : i32
+  }
+}
+}
+
+// -----
+
+// Resolve conflicting funcOp and specConstantOp.
+
+// CHECK:      module {
+// CHECK-NEXT:   spv.module Logical GLSL450 {
+// CHECK-NEXT:     spv.func @foo
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+
+// CHECK-NEXT:     spv.specConstant @foo_1
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+  spv.func @foo(%arg0 : i32) -> i32 "None" {
+    spv.ReturnValue %arg0 : i32
+  }
+}
+
+spv.module Logical GLSL450 {
+  spv.specConstant @foo = -5 : i32
+}
+}
+
+// -----
+
+// Resolve conflicting funcOp and specConstantOp and update the spec constant's
+// references.
+
+// CHECK:      module {
+// CHECK-NEXT:   spv.module Logical GLSL450 {
+// CHECK-NEXT:     spv.func @foo
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+
+// CHECK-NEXT:     spv.specConstant @foo_1
+// CHECK-NEXT:     spv.func @bar
+// CHECK-NEXT:       spv._reference_of @foo_1
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+  spv.func @foo(%arg0 : i32) -> i32 "None" {
+    spv.ReturnValue %arg0 : i32
+  }
+}
+
+spv.module Logical GLSL450 {
+  spv.specConstant @foo = -5 : i32
+
+  spv.func @bar() -> i32 "None" {
+    %0 = spv._reference_of @foo : i32 
+    spv.ReturnValue %0 : i32
+  }
+}
+}
+
+// -----
+
+// Resolve conflicting specConstantOp and funcOp and update the spec constant's
+// references.
+
+// CHECK:      module {
+// CHECK-NEXT:   spv.module Logical GLSL450 {
+// CHECK-NEXT:     spv.specConstant @foo_1
+// CHECK-NEXT:     spv.func @bar
+// CHECK-NEXT:       spv._reference_of @foo_1
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+
+// CHECK-NEXT:     spv.func @foo
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+  spv.specConstant @foo = -5 : i32
+
+  spv.func @bar() -> i32 "None" {
+    %0 = spv._reference_of @foo : i32
+    spv.ReturnValue %0 : i32
+  }
+}
+
+spv.module Logical GLSL450 {
+  spv.func @foo(%arg0 : i32) -> i32 "None" {
+    spv.ReturnValue %arg0 : i32
+  }
+}
+}
+
+// -----
+
+// Resolve conflicting funcOp and specConstantCompositeOp.
+
+// CHECK:      module {
+// CHECK-NEXT:   spv.module Logical GLSL450 {
+// CHECK-NEXT:     spv.func @foo
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+
+// CHECK-NEXT:     spv.specConstant @bar
+// CHECK-NEXT:     spv.specConstantComposite @foo_1 (@bar, @bar)
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+  spv.func @foo(%arg0 : i32) -> i32 "None" {
+    spv.ReturnValue %arg0 : i32
+  }
+}
+
+spv.module Logical GLSL450 {
+  spv.specConstant @bar = -5 : i32
+  spv.specConstantComposite @foo (@bar, @bar) : !spv.array<2 x i32>
+}
+}
+
+// -----
+
+// Resolve conflicting funcOp and specConstantCompositeOp and update the spec
+// constant's references.
+
+// CHECK:      module {
+// CHECK-NEXT:   spv.module Logical GLSL450 {
+// CHECK-NEXT:     spv.func @foo
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+
+// CHECK-NEXT:     spv.specConstant @bar
+// CHECK-NEXT:     spv.specConstantComposite @foo_1 (@bar, @bar)
+// CHECK-NEXT:     spv.func @baz
+// CHECK-NEXT:       spv._reference_of @foo_1
+// CHECK-NEXT:       spv.CompositeExtract
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+  spv.func @foo(%arg0 : i32) -> i32 "None" {
+    spv.ReturnValue %arg0 : i32
+  }
+}
+
+spv.module Logical GLSL450 {
+  spv.specConstant @bar = -5 : i32
+  spv.specConstantComposite @foo (@bar, @bar) : !spv.array<2 x i32>
+
+  spv.func @baz() -> i32 "None" {
+    %0 = spv._reference_of @foo : !spv.array<2 x i32>
+    %1 = spv.CompositeExtract %0[0 : i32] : !spv.array<2 x i32>
+    spv.ReturnValue %1 : i32
+  }
+}
+}
+
+// -----
+
+// Resolve conflicting specConstantCompositeOp and funcOp and update the spec
+// constant's references.
+
+// CHECK:      module {
+// CHECK-NEXT:   spv.module Logical GLSL450 {
+// CHECK-NEXT:     spv.specConstant @bar
+// CHECK-NEXT:     spv.specConstantComposite @foo_1 (@bar, @bar)
+// CHECK-NEXT:     spv.func @baz
+// CHECK-NEXT:       spv._reference_of @foo_1
+// CHECK-NEXT:       spv.CompositeExtract
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+
+// CHECK-NEXT:     spv.func @foo
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+  spv.specConstant @bar = -5 : i32
+  spv.specConstantComposite @foo (@bar, @bar) : !spv.array<2 x i32>
+
+  spv.func @baz() -> i32 "None" {
+    %0 = spv._reference_of @foo : !spv.array<2 x i32>
+    %1 = spv.CompositeExtract %0[0 : i32] : !spv.array<2 x i32>
+    spv.ReturnValue %1 : i32
+  }
+}
+
+spv.module Logical GLSL450 {
+  spv.func @foo(%arg0 : i32) -> i32 "None" {
+    spv.ReturnValue %arg0 : i32
+  }
+}
+}
+
+// -----
+
+// Resolve conflicting spec constants and funcOps and update the spec constant's
+// references.
+
+// CHECK:      module {
+// CHECK-NEXT:   spv.module Logical GLSL450 {
+// CHECK-NEXT:     spv.specConstant @bar_1
+// CHECK-NEXT:     spv.specConstantComposite @foo_2 (@bar_1, @bar_1)
+// CHECK-NEXT:     spv.func @baz
+// CHECK-NEXT:       spv._reference_of @foo_2
+// CHECK-NEXT:       spv.CompositeExtract
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+
+// CHECK-NEXT:     spv.func @foo
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+
+// CHECK-NEXT:     spv.func @bar
+// CHECK-NEXT:       spv.ReturnValue
+// CHECK-NEXT:     }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+  spv.specConstant @bar = -5 : i32
+  spv.specConstantComposite @foo (@bar, @bar) : !spv.array<2 x i32>
+
+  spv.func @baz() -> i32 "None" {
+    %0 = spv._reference_of @foo : !spv.array<2 x i32>
+    %1 = spv.CompositeExtract %0[0 : i32] : !spv.array<2 x i32>
+    spv.ReturnValue %1 : i32
+  }
+}
+
+spv.module Logical GLSL450 {
+  spv.func @foo(%arg0 : i32) -> i32 "None" {
+    spv.ReturnValue %arg0 : i32
+  }
+
+  spv.func @bar(%arg0 : f32) -> f32 "None" {
+    spv.ReturnValue %arg0 : f32
+  }
+}
+}
+
+// -----
+
+// Resolve conflicting globalVariableOps.
+
+// CHECK:      module {
+// CHECK-NEXT:   spv.module Logical GLSL450 {
+// CHECK-NEXT:     spv.globalVariable @foo_1
+
+// CHECK-NEXT:     spv.globalVariable @foo
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+  spv.globalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
+}
+
+spv.module Logical GLSL450 {
+  spv.globalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
+}
+}
+
+// -----
+
+// Resolve conflicting globalVariableOp and specConstantOp.
+
+// CHECK:      module {
+// CHECK-NEXT:   spv.module Logical GLSL450 {
+// CHECK-NEXT:     spv.globalVariable @foo_1
+
+// CHECK-NEXT:     spv.specConstant @foo
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+  spv.globalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
+}
+
+spv.module Logical GLSL450 {
+  spv.specConstant @foo = -5 : i32
+}
+}
+
+// -----
+
+// Resolve conflicting specConstantOp and globalVariableOp.
+
+// CHECK:      module {
+// CHECK-NEXT:   spv.module Logical GLSL450 {
+// CHECK-NEXT:     spv.specConstant @foo_1
+
+// CHECK-NEXT:     spv.globalVariable @foo
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+  spv.specConstant @foo = -5 : i32
+}
+
+spv.module Logical GLSL450 {
+  spv.globalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
+}
+}
+
+// -----
+
+// Resolve conflicting globalVariableOp and specConstantCompositeOp.
+
+// CHECK:      module {
+// CHECK-NEXT:   spv.module Logical GLSL450 {
+// CHECK-NEXT:     spv.globalVariable @foo_1
+
+// CHECK-NEXT:     spv.specConstant @bar
+// CHECK-NEXT:     spv.specConstantComposite @foo (@bar, @bar)
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+  spv.globalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
+}
+
+spv.module Logical GLSL450 {
+  spv.specConstant @bar = -5 : i32
+  spv.specConstantComposite @foo (@bar, @bar) : !spv.array<2 x i32>
+}
+}
+
+// -----
+
+// Resolve conflicting globalVariableOp and specConstantComposite.
+
+// CHECK:      module {
+// CHECK-NEXT:   spv.module Logical GLSL450 {
+// CHECK-NEXT:     spv.specConstant @bar
+// CHECK-NEXT:     spv.specConstantComposite @foo_1 (@bar, @bar)
+
+// CHECK-NEXT:     spv.globalVariable @foo
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+  spv.specConstant @bar = -5 : i32
+  spv.specConstantComposite @foo (@bar, @bar) : !spv.array<2 x i32>
+}
+
+spv.module Logical GLSL450 {
+  spv.globalVariable @foo bind(1, 0) : !spv.ptr<f32, Input>
+}
+}
index 204a633..6c74d2f 100644 (file)
@@ -2,6 +2,7 @@
 add_mlir_library(MLIRSPIRVTestPasses
   TestAvailability.cpp
   TestEntryPointAbi.cpp
+  TestModuleCombiner.cpp
 
   EXCLUDE_FROM_LIBMLIR
 
@@ -14,5 +15,6 @@ add_mlir_library(MLIRSPIRVTestPasses
   MLIRIR
   MLIRPass
   MLIRSPIRV
+  MLIRSPIRVModuleCombiner
   MLIRSupport
   )
diff --git a/mlir/test/lib/Dialect/SPIRV/TestModuleCombiner.cpp b/mlir/test/lib/Dialect/SPIRV/TestModuleCombiner.cpp
new file mode 100644 (file)
index 0000000..b321954
--- /dev/null
@@ -0,0 +1,48 @@
+//===- TestModuleCombiner.cpp - Pass to test SPIR-V module combiner lib ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/ModuleCombiner.h"
+
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Module.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+class TestModuleCombinerPass
+    : public PassWrapper<TestModuleCombinerPass,
+                         OperationPass<mlir::ModuleOp>> {
+public:
+  TestModuleCombinerPass() = default;
+  TestModuleCombinerPass(const TestModuleCombinerPass &) {}
+  void runOnOperation() override;
+
+private:
+  mlir::spirv::OwningSPIRVModuleRef combinedModule;
+};
+} // namespace
+
+void TestModuleCombinerPass::runOnOperation() {
+  auto modules = llvm::to_vector<4>(getOperation().getOps<spirv::ModuleOp>());
+
+  OpBuilder combinedModuleBuilder(modules[0]);
+  combinedModule = spirv::combine(modules, combinedModuleBuilder, nullptr);
+
+  for (spirv::ModuleOp module : modules)
+    module.erase();
+}
+
+namespace mlir {
+void registerTestSpirvModuleCombinerPass() {
+  PassRegistration<TestModuleCombinerPass> registration(
+      "test-spirv-module-combiner", "Tests SPIR-V module combiner library");
+}
+} // namespace mlir
index 196bda6..b5506a5 100644 (file)
@@ -79,6 +79,7 @@ void registerTestPrintNestingPass();
 void registerTestRecursiveTypesPass();
 void registerTestReducer();
 void registerTestSpirvEntryPointABIPass();
+void registerTestSpirvModuleCombinerPass();
 void registerTestSCFUtilsPass();
 void registerTestTraitsPass();
 void registerTestVectorConversions();
@@ -140,6 +141,7 @@ void registerTestPasses() {
   registerTestReducer();
   registerTestGpuParallelLoopMappingPass();
   registerTestSpirvEntryPointABIPass();
+  registerTestSpirvModuleCombinerPass();
   registerTestSCFUtilsPass();
   registerTestTraitsPass();
   registerTestVectorConversions();