Add support for getting TensorProto argument (#18364)
authorYinghai Lu <yinghai@fb.com>
Wed, 3 Apr 2019 03:52:58 +0000 (20:52 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 3 Apr 2019 03:58:28 +0000 (20:58 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18364

att

Reviewed By: bddppq

Differential Revision: D14584784

fbshipit-source-id: 03f9207d5cf4f7f4b812428a931edbcdcb21ca8d

caffe2/utils/proto_utils.cc

index 40cc1b8..c1d6458 100644 (file)
@@ -304,7 +304,14 @@ bool SupportsLosslessConversion(const InputType& value) {
   return static_cast<InputType>(static_cast<TargetType>(value)) == value;
 }
 }
+bool operator==(const TensorProto& l, const TensorProto& r) {
+  return l.SerializeAsString() == r.SerializeAsString();
+}
 
+std::ostream& operator<<(std::ostream& output, const TensorProto& n) {
+  output << n.SerializeAsString();
+  return output;
+}
 bool operator==(const NetDef& l, const NetDef& r) {
   return l.SerializeAsString() == r.SerializeAsString();
 }
@@ -404,6 +411,7 @@ INSTANTIATE_GET_REPEATED_ARGUMENT(uint16_t, ints, true)
 INSTANTIATE_GET_REPEATED_ARGUMENT(size_t, ints, true)
 INSTANTIATE_GET_REPEATED_ARGUMENT(string, strings, false)
 INSTANTIATE_GET_REPEATED_ARGUMENT(NetDef, nets, false)
+INSTANTIATE_GET_REPEATED_ARGUMENT(TensorProto, tensors, false)
 #undef INSTANTIATE_GET_REPEATED_ARGUMENT
 
 #define CAFFE2_MAKE_SINGULAR_ARGUMENT(T, fieldname)                      \