[MLIR][SPIRV] Use getAsmResultName(...) hook for ConstantOp.
authorKareemErgawy <kareem.ergawy@tomtom.com>
Fri, 28 May 2021 06:49:45 +0000 (08:49 +0200)
committerKareemErgawy <kareem.ergawy@tomtom.com>
Fri, 28 May 2021 07:28:02 +0000 (09:28 +0200)
Implements better naming for results of `spv.Constant` ops by making it
inherit from OpAsmOpInterface and implementing the associated
getAsmResultName(...) hook.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D103152

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/IR/asm-op-interface.mlir [new file with mode: 0644]
mlir/test/Dialect/SPIRV/IR/memory-ops.mlir

index 2de2bc0..410fc94 100644 (file)
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/OpImplementation.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "llvm/Support/PointerLikeTypeTraits.h"
index 0d787dd..c185cf0 100644 (file)
@@ -16,6 +16,7 @@
 #define MLIR_DIALECT_SPIRV_IR_STRUCTURE_OPS
 
 include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
+include "mlir/IR/OpAsmInterface.td"
 include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Interfaces/CallInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -67,7 +68,8 @@ def SPV_AddressOfOp : SPV_Op<"mlir.addressof", [InFunctionScope, NoSideEffect]>
 
 // -----
 
-def SPV_ConstantOp : SPV_Op<"Constant", [ConstantLike, NoSideEffect]> {
+def SPV_ConstantOp : SPV_Op<"Constant",
+    [ConstantLike, DeclareOpInterfaceMethods<OpAsmOpInterface>, NoSideEffect]> {
   let summary = "The op that declares a SPIR-V normal constant";
 
   let description = [{
index bf7c0e4..782d7dd 100644 (file)
@@ -1650,6 +1650,46 @@ spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
   llvm_unreachable("unimplemented types for ConstantOp::getOne()");
 }
 
+void mlir::spirv::ConstantOp::getAsmResultNames(
+    llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
+  Type type = getType();
+
+  SmallString<32> specialNameBuffer;
+  llvm::raw_svector_ostream specialName(specialNameBuffer);
+  specialName << "cst";
+
+  IntegerType intTy = type.dyn_cast<IntegerType>();
+
+  if (IntegerAttr intCst = value().dyn_cast<IntegerAttr>()) {
+    if (intTy && intTy.getWidth() == 1) {
+      return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
+    }
+
+    if (intTy.isSignless()) {
+      specialName << intCst.getInt();
+    } else {
+      specialName << intCst.getSInt();
+    }
+  }
+
+  if (intTy || type.isa<FloatType>()) {
+    specialName << '_' << type;
+  }
+
+  if (auto vecType = type.dyn_cast<VectorType>()) {
+    specialName << "_vec_";
+    specialName << vecType.getDimSize(0);
+
+    Type elementType = vecType.getElementType();
+
+    if (elementType.isa<IntegerType>() || elementType.isa<FloatType>()) {
+      specialName << "x" << elementType;
+    }
+  }
+
+  setNameFn(getResult(), specialName.str());
+}
+
 //===----------------------------------------------------------------------===//
 // spv.EntryPoint
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/asm-op-interface.mlir b/mlir/test/Dialect/SPIRV/IR/asm-op-interface.mlir
new file mode 100644 (file)
index 0000000..a53f061
--- /dev/null
@@ -0,0 +1,28 @@
+// RUN: mlir-opt %s | FileCheck %s
+
+func @const() -> () {
+  // CHECK: %true
+  %0 = spv.Constant true
+  // CHECK: %false
+  %1 = spv.Constant false
+
+  // CHECK: %cst42_i32
+  %2 = spv.Constant 42 : i32
+  // CHECK: %cst-42_i32
+  %-2 = spv.Constant -42 : i32
+  // CHECK: %cst43_i64
+  %3 = spv.Constant 43 : i64
+
+  // CHECK: %cst_f32
+  %4 = spv.Constant 0.5 : f32
+  // CHECK: %cst_f64
+  %5 = spv.Constant 0.5 : f64
+
+  // CHECK: %cst_vec_3xi32 
+  %6 = spv.Constant dense<[1, 2, 3]> : vector<3xi32>
+
+  // CHECK: %cst
+  %8 = spv.Constant [dense<3.0> : vector<2xf32>] : !spv.array<1xvector<2xf32>>
+
+  return
+}
index 22bbf31..62a43b0 100644 (file)
@@ -487,8 +487,9 @@ func @variable(%arg0: f32) -> () {
 // -----
 
 func @variable_init_normal_constant() -> () {
+  // CHECK: %[[cst:.*]] = spv.Constant
   %0 = spv.Constant 4.0 : f32
-  // CHECK: spv.Variable init(%0) : !spv.ptr<f32, Function>
+  // CHECK: spv.Variable init(%[[cst]]) : !spv.ptr<f32, Function>
   %1 = spv.Variable init(%0) : !spv.ptr<f32, Function>
   return
 }