Maintain a cache of output dtypes of ops in TFE_Context.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 26 Feb 2018 20:33:17 +0000 (12:33 -0800)
committerGunhan Gulsoy <gunan@google.com>
Tue, 27 Feb 2018 22:33:33 +0000 (14:33 -0800)
PiperOrigin-RevId: 187062992

tensorflow/c/eager/c_api.cc
tensorflow/c/eager/runtime.cc
tensorflow/c/eager/runtime.h

index c27a712..bebb63c 100644 (file)
@@ -33,6 +33,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
+#include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/rendezvous.h"
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/framework/types.h"
@@ -823,6 +824,25 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
       delete kernel;
       return;
     }
+    // Update output_dtypes inside `kernel`.
+    const tensorflow::OpDef* op_def = nullptr;
+    const tensorflow::FunctionDef* function_def =
+        ctx->func_lib_def.Find(ndef.op());
+    if (function_def != nullptr) {
+      op_def = &(function_def->signature());
+    }
+    if (op_def == nullptr) {
+      status->status = OpDefForOp(ndef.op().c_str(), &op_def);
+      if (!status->status.ok()) {
+        return;
+      }
+    }
+    tensorflow::DataTypeVector input_dtypes;
+    status->status = InOutTypesForNode(ndef, *op_def, &input_dtypes,
+                                       kernel->output_dtypes());
+    if (!status->status.ok()) {
+      return;
+    }
     tensorflow::mutex_lock ml(ctx->cache_mu);
     tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel);
   }
index f77a937..4bf24fe 100644 (file)
@@ -41,17 +41,26 @@ const uint32 kIsList = 1U << 31;
 
 }  // namespace
 
+Status OpDefForOp(const char* op_name, const OpDef** op_def) {
+  const OpRegistrationData* op_reg_data = nullptr;
+  Status s = OpRegistry::Global()->LookUp(op_name, &op_reg_data);
+  if (s.ok()) {
+    *op_def = &op_reg_data->op_def;
+  }
+  return s;
+}
+
 Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out) {
   mutex_lock l(g_op_name_to_attr_type_map_lock);
   *out = gtl::FindPtrOrNull(*OpNameToAttrTypeMap(), op_name);
   if (*out != nullptr) return Status::OK();
-  const OpRegistrationData* op_reg_data = nullptr;
-  Status s = OpRegistry::Global()->LookUp(op_name, &op_reg_data);
+  const OpDef* op_def = nullptr;
+  Status s = OpDefForOp(op_name, &op_def);
   if (!s.ok()) return s;
   std::unique_ptr<AttrTypeMap> m(new AttrTypeMap);
   // TODO(agarwal): Avoid having to create this "registry" at runtime,
   // perhaps can be done at op registration time?
-  for (const auto& attr : op_reg_data->op_def.attr()) {
+  for (const auto& attr : op_def->attr()) {
     string type = attr.type();
     const bool is_list = (type.length() > 6 && type.compare(0, 4, "list") == 0);
     if (is_list) {
index 4d20b52..7fede4d 100644 (file)
@@ -39,6 +39,9 @@ namespace tensorflow {
 // represent the TF_AttrType type of the values in the list.
 typedef std::unordered_map<string, uint32> AttrTypeMap;
 
+// Look up OpDef for `op_name`.
+Status OpDefForOp(const char* op_name, const OpDef** op_def);
+
 // Returns the AttrTypeMap for the TensorFlow operation named op_name.
 Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out);
 
@@ -180,12 +183,15 @@ class KernelAndDevice {
 
   const OpKernel* kernel() const { return kernel_.get(); }
 
+  DataTypeVector* output_dtypes() { return &output_dtypes_; }
+
  private:
   std::unique_ptr<OpKernel> kernel_;
   Device* device_;
   FunctionLibraryRuntime* flib_;
   checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_;
   Rendezvous* rendez_;
+  DataTypeVector output_dtypes_;
 };
 
 }  // namespace tensorflow