From 509e51bc809032bd3d9443bd4afc152fb5eaaf93 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Mon, 26 Feb 2018 12:33:17 -0800
Subject: [PATCH] Maintain a cache of output dtypes of ops in TFE_Context.

PiperOrigin-RevId: 187062992
---
Review note: amended so that both dtype-inference error paths in TFE_Execute
delete `kernel` before returning (it is not yet owned by ctx->kernel_cache),
and expanded the comments on OpDefForOp and KernelAndDevice::output_dtypes.
The blob ids on the `index` lines predate this amendment.

 tensorflow/c/eager/c_api.cc   | 24 ++++++++++++++++++++++++
 tensorflow/c/eager/runtime.cc | 15 ++++++++++++---
 tensorflow/c/eager/runtime.h  | 10 ++++++++++
 3 files changed, 46 insertions(+), 3 deletions(-)

diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index c27a712..bebb63c7 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -33,6 +33,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
+#include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/rendezvous.h"
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/framework/types.h"
@@ -823,6 +824,29 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
       delete kernel;
       return;
     }
+    // Update output_dtypes inside `kernel`. Functions registered with this
+    // context take precedence; otherwise fall back to the global op registry.
+    const tensorflow::OpDef* op_def = nullptr;
+    const tensorflow::FunctionDef* function_def =
+        ctx->func_lib_def.Find(ndef.op());
+    if (function_def != nullptr) {
+      op_def = &(function_def->signature());
+    }
+    if (op_def == nullptr) {
+      status->status = OpDefForOp(ndef.op().c_str(), &op_def);
+      if (!status->status.ok()) {
+        // `kernel` is not in the cache yet, so nothing else will free it.
+        delete kernel;
+        return;
+      }
+    }
+    tensorflow::DataTypeVector input_dtypes;
+    status->status = InOutTypesForNode(ndef, *op_def, &input_dtypes,
+                                       kernel->output_dtypes());
+    if (!status->status.ok()) {
+      delete kernel;
+      return;
+    }
     tensorflow::mutex_lock ml(ctx->cache_mu);
     tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel);
   }
diff --git a/tensorflow/c/eager/runtime.cc b/tensorflow/c/eager/runtime.cc
index f77a937..4bf24fe 100644
--- a/tensorflow/c/eager/runtime.cc
+++ b/tensorflow/c/eager/runtime.cc
@@ -41,17 +41,26 @@ const uint32 kIsList = 1U << 31;
 
 }  // namespace
 
+Status OpDefForOp(const char* op_name, const OpDef** op_def) {
+  const OpRegistrationData* op_reg_data = nullptr;
+  Status s = OpRegistry::Global()->LookUp(op_name, &op_reg_data);
+  if (s.ok()) {
+    *op_def = &op_reg_data->op_def;
+  }
+  return s;
+}
+
 Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out) {
   mutex_lock l(g_op_name_to_attr_type_map_lock);
   *out = gtl::FindPtrOrNull(*OpNameToAttrTypeMap(), op_name);
   if (*out != nullptr) return Status::OK();
-  const OpRegistrationData* op_reg_data = nullptr;
-  Status s = OpRegistry::Global()->LookUp(op_name, &op_reg_data);
+  const OpDef* op_def = nullptr;
+  Status s = OpDefForOp(op_name, &op_def);
   if (!s.ok()) return s;
   std::unique_ptr<AttrTypeMap> m(new AttrTypeMap);
   // TODO(agarwal): Avoid having to create this "registry" at runtime,
   // perhaps can be done at op registration time?
-  for (const auto& attr : op_reg_data->op_def.attr()) {
+  for (const auto& attr : op_def->attr()) {
     string type = attr.type();
     const bool is_list = (type.length() > 6 && type.compare(0, 4, "list") == 0);
     if (is_list) {
diff --git a/tensorflow/c/eager/runtime.h b/tensorflow/c/eager/runtime.h
index 4d20b52..7fede4d 100644
--- a/tensorflow/c/eager/runtime.h
+++ b/tensorflow/c/eager/runtime.h
@@ -39,6 +39,11 @@ namespace tensorflow {
 // represent the TF_AttrType type of the values in the list.
 typedef std::unordered_map<string, uint32> AttrTypeMap;
 
+// Looks up the OpDef registered for `op_name` in the global op registry.
+// On success, `*op_def` points at the registry's OpDef; on failure the lookup
+// status is returned and `*op_def` is left untouched.
+Status OpDefForOp(const char* op_name, const OpDef** op_def);
+
 // Returns the AttrTypeMap for the TensorFlow operation named op_name.
 Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out);
 
@@ -180,12 +185,17 @@ class KernelAndDevice {
 
   const OpKernel* kernel() const { return kernel_.get(); }
 
+  // Mutable pointer to the cached output dtypes; TFE_Execute populates it via
+  // InOutTypesForNode before the kernel is inserted into the kernel cache.
+  DataTypeVector* output_dtypes() { return &output_dtypes_; }
+
  private:
   std::unique_ptr<OpKernel> kernel_;
   Device* device_;
   FunctionLibraryRuntime* flib_;
   checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_;
   Rendezvous* rendez_;
+  DataTypeVector output_dtypes_;
 };
 
 }  // namespace tensorflow
-- 
2.7.4