From 9332ce7e8d394f3e40b7def72da7b1f761db2803 Mon Sep 17 00:00:00 2001 From: Mingsheng Hong Date: Fri, 2 Feb 2018 13:53:08 -0800 Subject: [PATCH] A few misc tweaks to TFE APIs. PiperOrigin-RevId: 184329345 --- tensorflow/c/eager/c_api.cc | 2 +- tensorflow/c/eager/c_api_internal.h | 2 ++ tensorflow/c/eager/runtime.cc | 9 ++++----- tensorflow/c/eager/runtime.h | 2 +- tensorflow/c/eager/runtime_test.cc | 6 +++--- 5 files changed, 11 insertions(+), 10 deletions(-) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index fd6cecd..d5b9bff 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -290,7 +290,7 @@ TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, return TF_ATTR_INT; // The compiler requires that we return something. } status->status = - tensorflow::AttrTypeByName(op->attr_types, attr_name, &ret, is_list); + tensorflow::AttrTypeByName(*op->attr_types, attr_name, &ret, is_list); return ret; } diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index dda6847..f2abffb 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -107,6 +107,8 @@ struct TFE_TensorHandle { }; struct TFE_Op { + // t is NULL iff the TFE_Op corresponds to a TensorFlow function instead of a + // primitive operation. TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t) : ctx(ctx), name(op), attrs(op), attr_types(t), device(nullptr) {} diff --git a/tensorflow/c/eager/runtime.cc b/tensorflow/c/eager/runtime.cc index 3a9951e..12abfcb 100644 --- a/tensorflow/c/eager/runtime.cc +++ b/tensorflow/c/eager/runtime.cc @@ -86,10 +86,9 @@ Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out) { return Status::OK(); } -Status AttrTypeByName(const AttrTypeMap* m, const string& attr_name, +Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name, TF_AttrType* out, unsigned char* is_list) { - CHECK(m); - auto* t = gtl::FindOrNull(*m, attr_name); + auto* t = gtl::FindOrNull(m, attr_name); if (t == nullptr) { return errors::InvalidArgument("Attribute '", attr_name, "' does not exist for this operation"); @@ -173,14 +172,14 @@ void CombineUnordered(const tensorflow::Fprint128& a, b->high64 += a.high64; } -inline tensorflow::Fprint128 CacheKeyHelper(const StringPiece& s, +inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s, const tensorflow::Fprint128& b) { // TODO(agarwal): avoid ToString(). tensorflow::Fprint128 a = tensorflow::Fingerprint128(s.ToString()); return FingerprintCat128(a, b); } -inline tensorflow::Fprint128 CacheKeyHelper(const StringPiece& s, uint64 b) { +inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s, uint64 b) { return CacheKeyHelper(s, {b, b}); } diff --git a/tensorflow/c/eager/runtime.h b/tensorflow/c/eager/runtime.h index e28a416..4d20b52 100644 --- a/tensorflow/c/eager/runtime.h +++ b/tensorflow/c/eager/runtime.h @@ -43,7 +43,7 @@ typedef std::unordered_map AttrTypeMap; Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out); // Looks for 'attr_name' in 'm' and sets 'out' and 'is_list'. -Status AttrTypeByName(const AttrTypeMap* m, const string& attr_name, +Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name, TF_AttrType* out, unsigned char* is_list); // KernelAndDevice::Init needs a NodeDef only to pass the attribute map through. diff --git a/tensorflow/c/eager/runtime_test.cc b/tensorflow/c/eager/runtime_test.cc index 2ccca66..6431530 100644 --- a/tensorflow/c/eager/runtime_test.cc +++ b/tensorflow/c/eager/runtime_test.cc @@ -63,17 +63,17 @@ TEST(AttrTypeMap, Lookup) { TF_AttrType t; unsigned char is_list = 1; - s = AttrTypeByName(m, "ThisAttribyteCannotPossiblyExist", &t, &is_list); + s = AttrTypeByName(*m, "ThisAttribyteCannotPossiblyExist", &t, &is_list); EXPECT_FALSE(s.ok()); EXPECT_NE(is_list, 0); - s = AttrTypeByName(m, "transpose_a", &t, &is_list); + s = AttrTypeByName(*m, "transpose_a", &t, &is_list); ASSERT_TRUE(s.ok()) << s; EXPECT_EQ(TF_ATTR_BOOL, t); EXPECT_EQ(is_list, 0); s = AttrTypeMapForOp("Squeeze", &m); ASSERT_TRUE(s.ok()) << s; - s = AttrTypeByName(m, "squeeze_dims", &t, &is_list); + s = AttrTypeByName(*m, "squeeze_dims", &t, &is_list); ASSERT_TRUE(s.ok()) << s; EXPECT_EQ(TF_ATTR_INT, t); EXPECT_NE(is_list, 0); -- 2.7.4