From 89418ddcb50034bbd8f631cbd71514721576426d Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Mon, 5 Sep 2022 11:54:19 +0000 Subject: [PATCH] Plumb write_bytecode to the Python API This adds a `write_bytecode` method to the Operation class. The method takes a file handle and writes the binary blob to it. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D133210 --- mlir/include/mlir-c/IR.h | 5 +++++ mlir/lib/Bindings/Python/IRCore.cpp | 17 +++++++++++++++++ mlir/lib/Bindings/Python/IRModule.h | 3 +++ mlir/lib/CAPI/IR/CMakeLists.txt | 1 + mlir/lib/CAPI/IR/IR.cpp | 8 +++++++- mlir/test/python/ir/operation.py | 12 ++++++++++++ utils/bazel/llvm-project-overlay/mlir/BUILD.bazel | 1 + 7 files changed, 46 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 2d38700..daf097d 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -521,6 +521,11 @@ MLIR_CAPI_EXPORTED void mlirOperationPrintWithFlags(MlirOperation op, MlirStringCallback callback, void *userData); +/// Same as mlirOperationPrint but writing the bytecode format out. +MLIR_CAPI_EXPORTED void mlirOperationWriteBytecode(MlirOperation op, + MlirStringCallback callback, + void *userData); + /// Prints an operation to stderr. MLIR_CAPI_EXPORTED void mlirOperationDump(MlirOperation op); diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index e83e993..389969b 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -119,6 +119,13 @@ Returns: argument. )"; +static const char kOperationPrintBytecodeDocstring[] = + R"(Write the bytecode form of the operation to a file like object. + +Args: + file: The file like object to write to. +)"; + static const char kOperationStrDunderDocstring[] = R"(Gets the assembly form of the operation with default options. @@ -1022,6 +1029,14 @@ void PyOperationBase::print(py::object fileObject, bool binary, mlirOpPrintingFlagsDestroy(flags); } +void PyOperationBase::writeBytecode(py::object fileObject) { + PyOperation &operation = getOperation(); + operation.checkValid(); + PyFileAccumulator accum(fileObject, /*binary=*/true); + mlirOperationWriteBytecode(operation, accum.getCallback(), + accum.getUserData()); +} + py::object PyOperationBase::getAsm(bool binary, llvm::Optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, @@ -2627,6 +2642,8 @@ void mlir::python::populateIRCore(py::module &m) { py::arg("print_generic_op_form") = false, py::arg("use_local_scope") = false, py::arg("assume_verified") = false, kOperationPrintDocstring) + .def("write_bytecode", &PyOperationBase::writeBytecode, py::arg("file"), + kOperationPrintBytecodeDocstring) .def("get_asm", &PyOperationBase::getAsm, // Careful: Lots of arguments must match up with get_asm method. py::arg("binary") = false, diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 246b244..ad783c6 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -512,6 +512,9 @@ public: bool printGenericOpForm, bool useLocalScope, bool assumeVerified); + // Implement the bound 'writeBytecode' method. + void writeBytecode(pybind11::object fileObject); + /// Moves the operation before or after the other operation. void moveAfter(PyOperationBase &other); void moveBefore(PyOperationBase &other); diff --git a/mlir/lib/CAPI/IR/CMakeLists.txt b/mlir/lib/CAPI/IR/CMakeLists.txt index 320ed07..36f2852 100644 --- a/mlir/lib/CAPI/IR/CMakeLists.txt +++ b/mlir/lib/CAPI/IR/CMakeLists.txt @@ -12,6 +12,7 @@ add_mlir_upstream_c_api_library(MLIRCAPIIR Support.cpp LINK_LIBS PUBLIC + MLIRBytecodeWriter MLIRIR MLIRParser MLIRSupport diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 435c974..98a3ff3 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -10,6 +10,7 @@ #include "mlir-c/Support.h" #include "mlir/AsmParser/AsmParser.h" +#include "mlir/Bytecode/BytecodeWriter.h" #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Support.h" #include "mlir/CAPI/Utils.h" @@ -23,7 +24,6 @@ #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Parser/Parser.h" -#include "llvm/Support/Debug.h" #include using namespace mlir; @@ -485,6 +485,12 @@ void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags, unwrap(op)->print(stream, *unwrap(flags)); } +void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback, + void *userData) { + detail::CallbackOstream stream(callback, userData); + writeBytecodeToFile(unwrap(op), stream); +} + void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); } bool mlirOperationVerify(MlirOperation op) { diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py index b328612..9190c30 100644 --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -566,6 +566,18 @@ def testOperationPrint(): print(str_value.__class__) print(f.getvalue()) + # Test roundtrip to bytecode. + bytecode_stream = io.BytesIO() + module.operation.write_bytecode(bytecode_stream) + bytecode = bytecode_stream.getvalue() + assert bytecode.startswith(b'ML\xefR'), "Expected bytecode to start with MLïR" + module_roundtrip = Module.parse(bytecode, ctx) + f = io.StringIO() + module_roundtrip.operation.print(file=f) + roundtrip_value = f.getvalue() + assert str_value == roundtrip_value, "Mismatch after roundtrip bytecode" + + # Test print to binary file. f = io.BytesIO() # CHECK: diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 22b94b9..33fdb94 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -398,6 +398,7 @@ mlir_c_api_cc_library( includes = ["include"], deps = [ ":AsmParser", + ":BytecodeWriter", ":ConversionPassIncGen", ":FuncDialect", ":InferTypeOpInterface", -- 2.7.4