caffe2/python/task: added __repr__ methods to all task definitions (#15250)
authorTristan Rice <tristanr@fb.com>
Mon, 17 Dec 2018 23:59:45 +0000 (15:59 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 18 Dec 2018 00:02:16 +0000 (16:02 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15250

This adds `__repr__` methods to all of the classes under task.py. This makes the objects much easier to interact with when using them in an interactive manner, such as in a Jupyter notebook.

The default `__repr__` method just returns the object ID which is very unhelpful.

Reviewed By: hanli0612

Differential Revision: D13475758

fbshipit-source-id: 6e1b166ec35163b9776c797b6a2e0d002560cd29

caffe2/python/task.py
caffe2/python/task_test.py [new file with mode: 0644]

index 161ba4f..eb7ad4e 100644 (file)
@@ -52,6 +52,10 @@ class Cluster(object):
     def node_kwargs(self):
         return self._node_kwargs
 
+    def __repr__(self):
+        return "Cluster(nodes={}, node_kwargs={})".format(
+            self.nodes(), self.node_kwargs())
+
 
 @context.define_context(allow_default=True)
 class Node(object):
@@ -85,6 +89,9 @@ class Node(object):
     def __str__(self):
         return self._name
 
+    def __repr__(self):
+        return "Node(name={}, kwargs={})".format(self._name, self._kwargs)
+
     def kwargs(self):
         return self._kwargs
 
@@ -345,6 +352,10 @@ class TaskGroup(object):
     def workspace_type(self):
         return self._workspace_type
 
+    def __repr__(self):
+        return "TaskGroup(tasks={}, workspace_type={}, remote_nets={})".format(
+            self.tasks(), self.workspace_type(), self.remote_nets())
+
 
 class TaskOutput(object):
     """
@@ -389,6 +400,9 @@ class TaskOutput(object):
         else:
             return fetched_vals
 
+    def __repr__(self):
+        return "TaskOutput(names={}, values={})".format(self.names, self._values)
+
 
 def final_output(blob_or_record):
     """
@@ -424,6 +438,9 @@ class TaskOutputList(object):
             offset += num
         assert offset == len(values), 'Wrong number of output values.'
 
+    def __repr__(self):
+        return "TaskOutputList(outputs={})".format(self.outputs)
+
 
 @context.define_context()
 class Task(object):
@@ -625,6 +642,10 @@ class Task(object):
         self.get_step()
         self._already_used = True
 
+    def __repr__(self):
+        return "Task(name={}, node={}, outputs={})".format(
+            self.name, self.node, self.outputs())
+
 
 class SetupNets(object):
     """
@@ -668,3 +689,7 @@ class SetupNets(object):
 
     def exit(self, exit_net):
         return self.exit_nets
+
+    def __repr__(self):
+        return "SetupNets(init_nets={}, exit_nets={})".format(
+            self.init_nets, self.exit_nets)
diff --git a/caffe2/python/task_test.py b/caffe2/python/task_test.py
new file mode 100644 (file)
index 0000000..f1c51bc
--- /dev/null
@@ -0,0 +1,24 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import unittest
+from caffe2.python import task
+
+
+class TestTask(unittest.TestCase):
+    def testRepr(self):
+        cases = [
+            (task.Cluster(), "Cluster(nodes=[], node_kwargs={})"),
+            (task.Node(), "Node(name=local, kwargs={})"),
+            (
+                task.TaskGroup(),
+                "TaskGroup(tasks=[], workspace_type=None, remote_nets=[])",
+            ),
+            (task.TaskOutput([]), "TaskOutput(names=[], values=None)"),
+            (task.Task(), "Task(name=local/task, node=local, outputs=[])"),
+            (task.SetupNets(), "SetupNets(init_nets=None, exit_nets=None)"),
+        ]
+        for obj, want in cases:
+            self.assertEqual(obj.__repr__(), want)