A few misc tweaks to TFE APIs.
authorMingsheng Hong <hongm@google.com>
Fri, 2 Feb 2018 21:53:08 +0000 (13:53 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 2 Feb 2018 22:02:02 +0000 (14:02 -0800)
PiperOrigin-RevId: 184329345

tensorflow/c/eager/c_api.cc
tensorflow/c/eager/c_api_internal.h
tensorflow/c/eager/runtime.cc
tensorflow/c/eager/runtime.h
tensorflow/c/eager/runtime_test.cc

index fd6cecd..d5b9bff 100644 (file)
@@ -290,7 +290,7 @@ TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
     return TF_ATTR_INT;  // The compiler requires that we return something.
   }
   status->status =
-      tensorflow::AttrTypeByName(op->attr_types, attr_name, &ret, is_list);
+      tensorflow::AttrTypeByName(*op->attr_types, attr_name, &ret, is_list);
   return ret;
 }
 
index dda6847..f2abffb 100644 (file)
@@ -107,6 +107,8 @@ struct TFE_TensorHandle {
 };
 
 struct TFE_Op {
+  // t is NULL iff the TFE_Op corresponds to a TensorFlow function instead of a
+  // primitive operation.
   TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t)
       : ctx(ctx), name(op), attrs(op), attr_types(t), device(nullptr) {}
 
index 3a9951e..12abfcb 100644 (file)
@@ -86,10 +86,9 @@ Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out) {
   return Status::OK();
 }
 
-Status AttrTypeByName(const AttrTypeMap* m, const string& attr_name,
+Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name,
                       TF_AttrType* out, unsigned char* is_list) {
-  CHECK(m);
-  auto* t = gtl::FindOrNull(*m, attr_name);
+  auto* t = gtl::FindOrNull(m, attr_name);
   if (t == nullptr) {
     return errors::InvalidArgument("Attribute '", attr_name,
                                    "' does not exist for this operation");
@@ -173,14 +172,14 @@ void CombineUnordered(const tensorflow::Fprint128& a,
   b->high64 += a.high64;
 }
 
-inline tensorflow::Fprint128 CacheKeyHelper(const StringPiece& s,
+inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s,
                                             const tensorflow::Fprint128& b) {
   // TODO(agarwal): avoid ToString().
   tensorflow::Fprint128 a = tensorflow::Fingerprint128(s.ToString());
   return FingerprintCat128(a, b);
 }
 
-inline tensorflow::Fprint128 CacheKeyHelper(const StringPiece& s, uint64 b) {
+inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s, uint64 b) {
   return CacheKeyHelper(s, {b, b});
 }
 
index e28a416..4d20b52 100644 (file)
@@ -43,7 +43,7 @@ typedef std::unordered_map<string, uint32> AttrTypeMap;
 Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out);
 
 // Looks for 'attr_name' in 'm' and sets 'out' and 'is_list'.
-Status AttrTypeByName(const AttrTypeMap* m, const string& attr_name,
+Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name,
                       TF_AttrType* out, unsigned char* is_list);
 
 // KernelAndDevice::Init needs a NodeDef only to pass the attribute map through.
index 2ccca66..6431530 100644 (file)
@@ -63,17 +63,17 @@ TEST(AttrTypeMap, Lookup) {
 
   TF_AttrType t;
   unsigned char is_list = 1;
-  s = AttrTypeByName(m, "ThisAttribyteCannotPossiblyExist", &t, &is_list);
+  s = AttrTypeByName(*m, "ThisAttribyteCannotPossiblyExist", &t, &is_list);
   EXPECT_FALSE(s.ok());
   EXPECT_NE(is_list, 0);
-  s = AttrTypeByName(m, "transpose_a", &t, &is_list);
+  s = AttrTypeByName(*m, "transpose_a", &t, &is_list);
   ASSERT_TRUE(s.ok()) << s;
   EXPECT_EQ(TF_ATTR_BOOL, t);
   EXPECT_EQ(is_list, 0);
 
   s = AttrTypeMapForOp("Squeeze", &m);
   ASSERT_TRUE(s.ok()) << s;
-  s = AttrTypeByName(m, "squeeze_dims", &t, &is_list);
+  s = AttrTypeByName(*m, "squeeze_dims", &t, &is_list);
   ASSERT_TRUE(s.ok()) << s;
   EXPECT_EQ(TF_ATTR_INT, t);
   EXPECT_NE(is_list, 0);