def node_kwargs(self):
return self._node_kwargs
+ def __repr__(self):
+ return "Cluster(nodes={}, node_kwargs={})".format(
+ self.nodes(), self.node_kwargs())
+
@context.define_context(allow_default=True)
class Node(object):
def __str__(self):
return self._name
+ def __repr__(self):
+ return "Node(name={}, kwargs={})".format(self._name, self._kwargs)
+
def kwargs(self):
return self._kwargs
def workspace_type(self):
return self._workspace_type
+ def __repr__(self):
+ return "TaskGroup(tasks={}, workspace_type={}, remote_nets={})".format(
+ self.tasks(), self.workspace_type(), self.remote_nets())
+
class TaskOutput(object):
"""
else:
return fetched_vals
+ def __repr__(self):
+ return "TaskOutput(names={}, values={})".format(self.names, self._values)
+
def final_output(blob_or_record):
"""
offset += num
assert offset == len(values), 'Wrong number of output values.'
+ def __repr__(self):
+ return "TaskOutputList(outputs={})".format(self.outputs)
+
@context.define_context()
class Task(object):
self.get_step()
self._already_used = True
+ def __repr__(self):
+ return "Task(name={}, node={}, outputs={})".format(
+ self.name, self.node, self.outputs())
+
class SetupNets(object):
"""
def exit(self, exit_net):
return self.exit_nets
+
+ def __repr__(self):
+ return "SetupNets(init_nets={}, exit_nets={})".format(
+ self.init_nets, self.exit_nets)
--- /dev/null
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import unittest
+from caffe2.python import task
+
+
+class TestTask(unittest.TestCase):
+ def testRepr(self):
+ cases = [
+ (task.Cluster(), "Cluster(nodes=[], node_kwargs={})"),
+ (task.Node(), "Node(name=local, kwargs={})"),
+ (
+ task.TaskGroup(),
+ "TaskGroup(tasks=[], workspace_type=None, remote_nets=[])",
+ ),
+ (task.TaskOutput([]), "TaskOutput(names=[], values=None)"),
+ (task.Task(), "Task(name=local/task, node=local, outputs=[])"),
+ (task.SetupNets(), "SetupNets(init_nets=None, exit_nets=None)"),
+ ]
+ for obj, want in cases:
+ self.assertEqual(obj.__repr__(), want)