NetSpec: don't require lists to specify single-element repeated fields
authorJeff Donahue <jeff.donahue@gmail.com>
Sat, 22 Aug 2015 00:29:06 +0000 (17:29 -0700)
committerJeff Donahue <jeff.donahue@gmail.com>
Thu, 3 Sep 2015 23:31:28 +0000 (16:31 -0700)
python/caffe/net_spec.py
python/caffe/test/test_net_spec.py

index 77a0e00..93fc019 100644 (file)
@@ -56,8 +56,14 @@ def to_proto(*tops):
 def assign_proto(proto, name, val):
     """Assign a Python object to a protobuf message, based on the Python
     type (in recursive fashion). Lists become repeated fields/messages, dicts
-    become messages, and other types are assigned directly."""
-
+    become messages, and other types are assigned directly. For convenience,
+    repeated fields whose values are not lists are converted to single-element
+    lists; e.g., `my_repeated_int_field=3` is converted to
+    `my_repeated_int_field=[3]`."""
+
+    is_repeated_field = hasattr(getattr(proto, name), 'extend')
+    if is_repeated_field and not isinstance(val, list):
+        val = [val]
     if isinstance(val, list):
         if isinstance(val[0], dict):
             for item in val:
index b4595e6..fee3c0a 100644 (file)
@@ -43,8 +43,7 @@ def anon_lenet(batch_size):
 
 def silent_net():
     n = caffe.NetSpec()
-    n.data, n.data2 = L.DummyData(shape=[dict(dim=[3]), dict(dim=[4, 2])],
-                                  ntop=2)
+    n.data, n.data2 = L.DummyData(shape=dict(dim=3), ntop=2)
     n.silence_data = L.Silence(n.data, ntop=0)
     n.silence_data2 = L.Silence(n.data2, ntop=0)
     return n.to_proto()