From e493abcf55a35812d15e16477958baa4bdc92707 Mon Sep 17 00:00:00 2001 From: KareemErgawy Date: Fri, 28 May 2021 08:49:45 +0200 Subject: [PATCH] [MLIR][SPIRV] Use getAsmResultName(...) hook for ConstantOp. Implements better naming for results of `spv.Constant` ops by making it inherit from OpAsmOpInterface and implementing the associated getAsmResultName(...) hook. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D103152 --- mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h | 1 + .../mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td | 4 ++- mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 40 ++++++++++++++++++++++ mlir/test/Dialect/SPIRV/IR/asm-op-interface.mlir | 28 +++++++++++++++ mlir/test/Dialect/SPIRV/IR/memory-ops.mlir | 3 +- 5 files changed, 74 insertions(+), 2 deletions(-) create mode 100644 mlir/test/Dialect/SPIRV/IR/asm-op-interface.mlir diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h index 2de2bc0..410fc94 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h @@ -16,6 +16,7 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "llvm/Support/PointerLikeTypeTraits.h" diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td index 0d787dd..c185cf0 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td @@ -16,6 +16,7 @@ #define MLIR_DIALECT_SPIRV_IR_STRUCTURE_OPS include "mlir/Dialect/SPIRV/IR/SPIRVBase.td" +include "mlir/IR/OpAsmInterface.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -67,7 +68,8 @@ def SPV_AddressOfOp : SPV_Op<"mlir.addressof", [InFunctionScope, NoSideEffect]> // ----- -def SPV_ConstantOp : SPV_Op<"Constant", [ConstantLike, NoSideEffect]> { +def SPV_ConstantOp : SPV_Op<"Constant", + [ConstantLike, DeclareOpInterfaceMethods, NoSideEffect]> { let summary = "The op that declares a SPIR-V normal constant"; let description = [{ diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index bf7c0e4..782d7dd 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -1650,6 +1650,46 @@ spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc, llvm_unreachable("unimplemented types for ConstantOp::getOne()"); } +void mlir::spirv::ConstantOp::getAsmResultNames( + llvm::function_ref setNameFn) { + Type type = getType(); + + SmallString<32> specialNameBuffer; + llvm::raw_svector_ostream specialName(specialNameBuffer); + specialName << "cst"; + + IntegerType intTy = type.dyn_cast(); + + if (IntegerAttr intCst = value().dyn_cast()) { + if (intTy && intTy.getWidth() == 1) { + return setNameFn(getResult(), (intCst.getInt() ? "true" : "false")); + } + + if (intTy.isSignless()) { + specialName << intCst.getInt(); + } else { + specialName << intCst.getSInt(); + } + } + + if (intTy || type.isa()) { + specialName << '_' << type; + } + + if (auto vecType = type.dyn_cast()) { + specialName << "_vec_"; + specialName << vecType.getDimSize(0); + + Type elementType = vecType.getElementType(); + + if (elementType.isa() || elementType.isa()) { + specialName << "x" << elementType; + } + } + + setNameFn(getResult(), specialName.str()); +} + //===----------------------------------------------------------------------===// // spv.EntryPoint //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/IR/asm-op-interface.mlir b/mlir/test/Dialect/SPIRV/IR/asm-op-interface.mlir new file mode 100644 index 0000000..a53f061 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/IR/asm-op-interface.mlir @@ -0,0 +1,28 @@ +// RUN: mlir-opt %s | FileCheck %s + +func @const() -> () { + // CHECK: %true + %0 = spv.Constant true + // CHECK: %false + %1 = spv.Constant false + + // CHECK: %cst42_i32 + %2 = spv.Constant 42 : i32 + // CHECK: %cst-42_i32 + %-2 = spv.Constant -42 : i32 + // CHECK: %cst43_i64 + %3 = spv.Constant 43 : i64 + + // CHECK: %cst_f32 + %4 = spv.Constant 0.5 : f32 + // CHECK: %cst_f64 + %5 = spv.Constant 0.5 : f64 + + // CHECK: %cst_vec_3xi32 + %6 = spv.Constant dense<[1, 2, 3]> : vector<3xi32> + + // CHECK: %cst + %8 = spv.Constant [dense<3.0> : vector<2xf32>] : !spv.array<1xvector<2xf32>> + + return +} diff --git a/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir b/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir index 22bbf31..62a43b0f 100644 --- a/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir @@ -487,8 +487,9 @@ func @variable(%arg0: f32) -> () { // ----- func @variable_init_normal_constant() -> () { + // CHECK: %[[cst:.*]] = spv.Constant %0 = spv.Constant 4.0 : f32 - // CHECK: spv.Variable init(%0) : !spv.ptr + // CHECK: spv.Variable init(%[[cst]]) : !spv.ptr %1 = spv.Variable init(%0) : !spv.ptr return } -- 2.7.4