Started to open source the RL placer.
authorBenoit Steiner <bsteiner@google.com>
Thu, 22 Feb 2018 05:05:42 +0000 (21:05 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 22 Feb 2018 05:09:41 +0000 (21:09 -0800)
PiperOrigin-RevId: 186563773

tensorflow/python/BUILD
tensorflow/python/grappler/cluster.i
tensorflow/python/grappler/cluster_test.py
tensorflow/python/grappler/controller.py [new file with mode: 0644]
tensorflow/python/grappler/graph_placer.py [new file with mode: 0644]
tensorflow/python/grappler/graph_placer_test.py [new file with mode: 0644]
tensorflow/python/grappler/hierarchical_controller.py [new file with mode: 0644]
tensorflow/python/grappler/item.i
tensorflow/python/grappler/item_test.py

index 9b0c800..6a7ece4 100644 (file)
@@ -4593,6 +4593,34 @@ py_test(
     ],
 )
 
+py_library(
+    name = "graph_placer",
+    srcs = [
+        "grappler/controller.py",
+        "grappler/graph_placer.py",
+        "grappler/hierarchical_controller.py",
+    ],
+    deps = [
+        ":python",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_test(
+    name = "graph_placer_test",
+    size = "large",
+    srcs = ["grappler/graph_placer_test.py"],
+    tags = [
+        "grappler",
+        "no_pip",  # graph_placer is not available in pip.
+    ],
+    deps = [
+        ":client_testlib",
+        ":graph_placer",
+        "//tensorflow/python:math_ops",
+    ],
+)
+
 py_test(
     name = "memory_optimizer_test",
     size = "medium",
index 8079cb3..067c821 100644 (file)
@@ -206,7 +206,7 @@ static PyObject* TF_ListDevices(GCluster cluster) {
   return result;
 }
 
-static std::vector<string> TF_ListAvailableOps() {
+static PyObject* TF_ListAvailableOps() {
   tensorflow::OpRegistry* registry = tensorflow::OpRegistry::Global();
   std::vector<tensorflow::OpDef> ops;
   registry->GetRegisteredOps(&ops);
@@ -215,7 +215,14 @@ static std::vector<string> TF_ListAvailableOps() {
     op_names.push_back(op.name());
   }
   std::sort(op_names.begin(), op_names.end());
-  return op_names;
+
+  PyGILState_STATE gstate = PyGILState_Ensure();
+  PyObject* result = PyList_New(op_names.size());
+  for (int i = 0; i < op_names.size(); ++i) {
+    PyList_SetItem(result, i, PyString_FromString(op_names[i].c_str()));
+  }
+  PyGILState_Release(gstate);
+  return result;
 }
 
 static PyObject* TF_GetSupportedDevices(GCluster cluster, GItem item) {
@@ -432,7 +439,7 @@ static GCluster TF_NewVirtualCluster(
     TF_Status* out_status);
 static void TF_ShutdownCluster(GCluster cluster);
 static PyObject* TF_ListDevices(GCluster cluster);
-static std::vector<string> TF_ListAvailableOps();
+static PyObject* TF_ListAvailableOps();
 static PyObject* TF_GetSupportedDevices(GCluster cluster, GItem item);
 static float TF_EstimatePerformance(const tensorflow::NamedDevice& device);
 static PyObject* TF_MeasureCosts(
index caae5b1..a3c4c2b 100644 (file)
@@ -131,8 +131,8 @@ class ClusterTest(test.TestCase):
   def testAvailableOps(self):
     with cluster.Provision() as gcluster:
       op_names = gcluster.ListAvailableOps()
-      self.assertTrue(b'Add' in op_names)
-      self.assertTrue(b'MatMul' in op_names)
+      self.assertTrue('Add' in op_names)
+      self.assertTrue('MatMul' in op_names)
       self.assertEqual(op_names, sorted(op_names))
 
   def testSupportDevices(self):
diff --git a/tensorflow/python/grappler/controller.py b/tensorflow/python/grappler/controller.py
new file mode 100644 (file)
index 0000000..5677f4f
--- /dev/null
@@ -0,0 +1,142 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Controller Class."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from collections import defaultdict
+
+
+class Controller(object):
+  """Controller class."""
+
+  def __init__(self, item, cluster):
+    """Controller class initializer.
+
+    Args:
+      item: The metagraph to place wrapped in a cluster.
+      cluster: A cluster of devices on which to place the item.
+    """
+    self.item = item
+
+    self._node = {}
+    for node in item.metagraph.graph_def.node:
+      self._node[node.name] = node
+
+    self._fanout = defaultdict(lambda: [])
+    for node in item.metagraph.graph_def.node:
+      for fanin in self._get_node_fanin(node):
+        self._fanout[fanin.name].append(node)
+
+    important_op_names = item.IdentifyImportantOps(sort_topologically=True)
+
+    # List of important ops (these are the ops to place) sorted in topological
+    # order. The order of this collection is deterministic.
+    self.important_ops = []
+    for name in important_op_names:
+      self.important_ops.append(self._node[name])
+
+    self.node_properties = item.GetOpProperties()
+
+    self.cluster = cluster
+    self.devices = cluster.ListDevices()
+
+    self.colocation_constraints = item.GetColocationGroups()
+
+    self.placement_constraints = cluster.GetSupportedDevices(item)
+    for node_name, dev in self.placement_constraints.items():
+      if len(dev) == 1:
+        # Place the node on the supported device
+        node = self._node[node_name]
+        node.device = dev[0]
+        fanout = self.get_node_fanout(node)
+        # Update the fanout of the fanin to bypass the node
+        for fanin in self._get_node_fanin(node):
+          fanout_of_fanin = self.get_node_fanout(fanin)
+          fanout_of_fanin += fanout
+          fanout_of_fanin.remove(node)
+        # Remove node from the list of important ops since we don't need to
+        # place the node.
+        if node in self.important_ops:
+          self.important_ops.remove(node)
+          important_op_names.remove(node.name)
+
+    # List of important op names, in non deterministic order.
+    self.important_op_names = frozenset(important_op_names)
+
+  @property
+  def input_graph_def(self):
+    return self.item.metagraph.graph_def
+
+  @property
+  def num_devices(self):
+    return len(self.devices)
+
+  def get_node_by_name(self, node_name):
+    return self._node[node_name]
+
+  def get_node_fanout(self, node):
+    return self._fanout[node.name]
+
+  def get_placements(self, *args, **kwargs):
+    """Returns: Two TF ops.
+
+    Args:
+      *args: "".
+      **kwargs: "".
+
+    Returns:
+      y_preds: tensor of size [batch_size, num_ops]
+      log_probs: python dict of at least two fields: "sample", "target" each
+      containing a tensor of size [batch_size], corresponding to the log_probs.
+    """
+    raise NotImplementedError
+
+  def eval_placement(self, sess, *args, **kwargs):
+    """At this time, this method evaluates ONLY ONE placement.
+
+    Args:
+      sess: a tf.Session() object used to retrieve cached assignment info.
+      *args: "".
+      **kwargs: "".
+
+    Returns:
+      run_time: scalar
+    """
+    raise NotImplementedError
+
+  def export_placement(self, metagraph):
+    """Annotate the placement onto the specified metagraph.
+
+    Args:
+      metagraph: the metagraph to annotate with the placement.
+    """
+    for node in metagraph.graph_def.node:
+      if node.name in self.important_op_names:
+        node.device = self.get_node_by_name(node.name).device
+
+  # Get the nodes in the immediate fanin of node.
+  # Beware: this doesn't take into account the nodes that may be skipped
+  # since placement constraints force their placement.
+  def _get_node_fanin(self, node):
+    input_ops = []
+    for fanin_name in node.input:
+      if fanin_name[0] == "^":
+        fanin_name = fanin_name[1:]
+      fanin_name = fanin_name.split(":")[0]
+      input_ops.append(self.get_node_by_name(fanin_name))
+    return input_ops
diff --git a/tensorflow/python/grappler/graph_placer.py b/tensorflow/python/grappler/graph_placer.py
new file mode 100644 (file)
index 0000000..2cc3536
--- /dev/null
@@ -0,0 +1,110 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Graph Placer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+from tensorflow.core.protobuf import meta_graph_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops as tf_ops
+from tensorflow.python.grappler import cluster as gcluster
+from tensorflow.python.grappler import hierarchical_controller
+from tensorflow.python.grappler import item as gitem
+from tensorflow.python.grappler import tf_optimizer
+from tensorflow.python.training import training
+
+
+def PlaceGraph(metagraph,
+               cluster=None,
+               allotted_time=3600,
+               hparams=None,
+               verbose=False):
+  """Place the provided metagraph.
+
+  Args:
+    metagraph: the metagraph to place.
+    cluster: an optional set of hardware resource to optimize the placement for.
+      If none is specified, we'll optimize the placement for the hardware
+      available on the local machine.
+    allotted_time: the maximum amount to time in seconds to spend optimizing
+      the placement.
+    hparams: hyperparameters used to fine tune the placer.
+    verbose: prints debug information if True.
+
+  Returns:
+    The placed metagraph.
+  """
+  if cluster is None:
+    cluster = gcluster.Cluster()
+
+  # Optimize the metagraph to speedup the placement
+  rewriter_config = rewriter_config_pb2.RewriterConfig()
+  rewriter_config.optimizers.append("pruning")
+  rewriter_config.optimizers.append("constfold")
+  rewriter_config.optimizers.append("arithmetic")
+  rewriter_config.optimizers.append("dependency")
+  rewriter_config.optimizers.append("pruning")
+  optimized_graph = tf_optimizer.OptimizeGraph(
+      rewriter_config, metagraph, verbose=verbose, cluster=cluster)
+  optimized_metagraph = meta_graph_pb2.MetaGraphDef()
+  optimized_metagraph.CopyFrom(metagraph)
+  optimized_metagraph.graph_def.CopyFrom(optimized_graph)
+
+  item = gitem.Item(optimized_metagraph)
+
+  if hparams is None:
+    hparams = hierarchical_controller.hierarchical_controller_hparams()
+  # We run with a single child
+  hparams.num_children = 1
+
+  with tf_ops.Graph().as_default():
+    # Place all the nodes of the controller on the CPU. We don't want them to
+    # fight for accelerator memory with the model to optimize.
+    with tf_ops.device("/device:CPU:0"):
+      model = hierarchical_controller.HierarchicalController(
+          hparams, item, cluster)
+      ops = model.build_controller()
+      session_creator = training.ChiefSessionCreator()
+      with training.MonitoredSession(session_creator=session_creator) as sess:
+        start_time = time.time()
+        current_time = start_time
+        while current_time - start_time < allotted_time:
+          grouping_actions = model.generate_grouping(sess)
+          input_to_seq2seq = model.create_group_embeddings(
+              grouping_actions, verbose=verbose)
+          model.generate_placement(input_to_seq2seq, sess)
+          try:
+            run_time = model.eval_placement(
+                sess,
+                verbose=verbose)
+          except errors.OpError as e:
+            if verbose:
+              print("Failed to run graph:" + str(e))
+            run_time = hparams.failing_signal
+          updated = model.update_reward(sess, run_time, verbose=verbose)
+          if updated:
+            if verbose:
+              print("Found better placement, with runtime " + str(run_time))
+            model.export_placement(metagraph)
+
+          model.process_reward(sess)
+
+          current_time = time.time()
+
+  return metagraph
diff --git a/tensorflow/python/grappler/graph_placer_test.py b/tensorflow/python/grappler/graph_placer_test.py
new file mode 100644 (file)
index 0000000..9eabe3c
--- /dev/null
@@ -0,0 +1,140 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests the graph placer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from tensorflow.core.protobuf import device_properties_pb2
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import meta_graph
+from tensorflow.python.framework import ops as tf_ops
+from tensorflow.python.grappler import cluster
+from tensorflow.python.grappler import graph_placer
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import test
+
+
+class GraphPlacerTest(test.TestCase):
+
+  @staticmethod
+  def _buildMnist(batch_size=128,
+                  input_size=256,
+                  num_classes=1024,
+                  num_layers=10,
+                  hidden_size=256,
+                  name='mnist'):
+    g = tf_ops.get_default_graph()
+    with g.as_default():
+      ops = {}
+      x = random_ops.random_uniform(
+          [batch_size, input_size], -0.1, 0.1, dtype=dtypes.float32)
+      for layer_id in range(num_layers):
+        with variable_scope.variable_scope('layer_{}'.format(layer_id)):
+          a = input_size if layer_id == 0 else hidden_size
+          b = hidden_size if layer_id < num_layers - 1 else num_classes
+          w = variable_scope.get_variable('w', [a, b])
+          x = math_ops.matmul(x, w)
+          x = nn_ops.relu(x)
+      ops['y_preds'] = math_ops.argmax(x, axis=1)
+
+    train_op = g.get_collection_ref(tf_ops.GraphKeys.TRAIN_OP)
+    train_op.append(ops['y_preds'])
+    return g
+
+  @staticmethod
+  def _buildCluster(num_cpus=1, num_gpus=1):
+    devices = []
+    if num_gpus > 0:
+      device_properties = device_properties_pb2.DeviceProperties(
+          type='GPU',
+          vendor='NVidia',
+          model='GeForce GTX TITAN X',
+          frequency=1076,
+          num_cores=24,
+          environment={'architecture': '5.2',
+                       'cuda': '8000',
+                       'cudnn': '6021'},
+          num_registers=65536,
+          l1_cache_size=24576,
+          l2_cache_size=3145728,
+          shared_memory_size_per_multiprocessor=98304,
+          memory_size=12783648768,
+          bandwidth=336480000)
+      for i in range(num_gpus):
+        devices.append(
+            device_properties_pb2.NamedDevice(
+                properties=device_properties, name='/GPU:' + str(i)))
+
+    assert num_cpus > 0
+    device_properties = device_properties_pb2.DeviceProperties(
+        type='CPU',
+        frequency=2000,
+        num_cores=4,
+        l1_cache_size=32768,
+        l2_cache_size=262144,
+        l3_cache_size=12582912)
+    for i in range(num_cpus):
+      devices.append(
+          device_properties_pb2.NamedDevice(
+              properties=device_properties, name='/CPU:' + str(i)))
+
+    return cluster.Cluster(devices=devices)
+
+  def testBasic(self):
+    """Place a trivial graph."""
+    a = constant_op.constant(10, name='a')
+    b = constant_op.constant(20, name='b')
+    c = math_ops.add_n([a, b], name='c')
+    d = math_ops.add_n([b, c], name='d')
+    train_op = tf_ops.get_collection_ref(tf_ops.GraphKeys.TRAIN_OP)
+    train_op.append(d)
+    mg = meta_graph.create_meta_graph_def(graph=tf_ops.get_default_graph())
+
+    gcluster = cluster.Cluster()
+    placed_mg = graph_placer.PlaceGraph(mg, allotted_time=15, cluster=gcluster)
+
+    self.assertEqual(4, len(placed_mg.graph_def.node))
+    self.assertItemsEqual([node.name for node in placed_mg.graph_def.node],
+                          [node.name for node in mg.graph_def.node])
+
+    available_devices = [device.name for device in gcluster.ListDevices()]
+    for node in placed_mg.graph_def.node:
+      # The constant nodes are optimized away before the placer is run, and
+      # therefore won't be placed.
+      self.assertTrue(not node.device or node.device in available_devices)
+
+  def testMNIST(self):
+    graph = GraphPlacerTest._buildMnist()
+    mg = meta_graph.create_meta_graph_def(graph=graph)
+    gcluster = GraphPlacerTest._buildCluster(num_gpus=1)
+    # Spend 15 seconds trying to optimize the placement of the model. This
+    # should give us enough time to exercise the code, but not enough to find
+    # a good placement, so we'll just check for legality.
+    placed_mg = graph_placer.PlaceGraph(mg, allotted_time=15, cluster=gcluster)
+    self.assertEqual(len(placed_mg.graph_def.node), len(mg.graph_def.node))
+    self.assertItemsEqual([node.name for node in placed_mg.graph_def.node],
+                          [node.name for node in mg.graph_def.node])
+    available_devices = [device.name for device in gcluster.ListDevices()]
+    for node in placed_mg.graph_def.node:
+      self.assertTrue(not node.device or node.device in available_devices)
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/grappler/hierarchical_controller.py b/tensorflow/python/grappler/hierarchical_controller.py
new file mode 100644 (file)
index 0000000..655e43e
--- /dev/null
@@ -0,0 +1,1098 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""HierarchicalController Class.
+
+The HierarchicalController encompasses the entire lifecycle of training the
+device placement policy, including generating op embeddings, getting groups for
+each op, placing those groups and running the predicted placements.
+
+Different assignment models can inherit from this class.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import numpy as np
+import six
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops as tf_ops
+from tensorflow.python.grappler.controller import Controller
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.summary import summary
+from tensorflow.python.training import adam
+from tensorflow.python.training import gradient_descent
+from tensorflow.python.training import learning_rate_decay
+from tensorflow.python.training import training_util
+
+
+class PlacerParams(object):
+  """Class to hold a set of placement parameters as name-value pairs.
+
+  A typical usage is as follows:
+
+  ```python
+  # Create a PlacerParams object specifying names and values of the model
+  # parameters:
+  params = PlacerParams(hidden_size=128, decay_steps=50)
+
+  # The parameters are available as attributes of the PlacerParams object:
+  hparams.hidden_size ==> 128
+  hparams.decay_steps ==> 50
+  ```
+
+  """
+
+  def __init__(self, **kwargs):
+    """Create an instance of `PlacerParams` from keyword arguments.
+
+    The keyword arguments specify name-values pairs for the parameters.
+    The parameter types are inferred from the type of the values passed.
+
+    The parameter names are added as attributes of `PlacerParams` object,
+    and they can be accessed directly with the dot notation `params._name_`.
+
+    Example:
+
+    ```python
+    # Define 1 parameter: 'hidden_size'
+    params = PlacerParams(hidden_size=128)
+    params.hidden_size ==> 128
+    ```
+
+    Args:
+      **kwargs: Key-value pairs where the key is the parameter name and
+        the value is the value for the parameter.
+    """
+    for name, value in six.iteritems(kwargs):
+      self.add_param(name, value)
+
+  def add_param(self, name, value):
+    """Adds {name, value} pair to hyperparameters.
+
+    Args:
+      name: Name of the hyperparameter.
+      value: Value of the hyperparameter. Can be one of the following types:
+        int, float, string, int list, float list, or string list.
+
+    Raises:
+      ValueError: if one of the arguments is invalid.
+    """
+    # Keys in kwargs are unique, but 'name' could be the name of a pre-existing
+    # attribute of this object.  In that case we refuse to use it as a
+    # parameter name.
+    if getattr(self, name, None) is not None:
+      raise ValueError("Parameter name is reserved: %s" % name)
+    setattr(self, name, value)
+
+
+def hierarchical_controller_hparams():
+  """Hyperparameters for hierarchical planner."""
+  return PlacerParams(
+      hidden_size=512,
+      forget_bias_init=1.0,
+      temperature=1.0,
+      logits_std_noise=0.5,
+      stop_noise_step=750,
+      decay_steps=50,
+      max_num_outputs=5,
+      max_output_size=5,
+      tanh_constant=1.0,
+      adj_embed_dim=20,
+      grouping_hidden_size=64,
+      num_groups=None,
+      bi_lstm=True,
+      failing_signal=100,
+      stop_sampling=500,
+      start_with_failing_signal=True,
+      always_update_baseline=False,
+      bl_dec=0.9,
+      grad_bound=1.0,
+      lr=0.1,
+      lr_dec=0.95,
+      start_decay_step=400,
+      optimizer_type="adam",
+      stop_updating_after_steps=1000,
+      name="hierarchical_controller",
+      keep_prob=1.0,
+      reward_function="sqrt",
+      seed=1234,
+      # distributed training params
+      num_children=1)
+
+
+class HierarchicalController(Controller):
+  """HierarchicalController class."""
+
+  def __init__(self, hparams, item, cluster, controller_id=0):
+    """HierarchicalController class initializer.
+
+    Args:
+      hparams: All hyper-parameters.
+      item: The metagraph to place.
+      cluster: The cluster of hardware devices to optimize for.
+      controller_id: the id of the controller in a multi-controller setup.
+    """
+    super(HierarchicalController, self).__init__(item, cluster)
+    self.ctrl_id = controller_id
+    self.hparams = hparams
+
+    if self.hparams.num_groups is None:
+      self.num_groups = min(256, 20 * self.num_devices)
+    else:
+      self.num_groups = self.hparams.num_groups
+
+    # creates self.op_embeddings and self.type_dict
+    self.create_op_embeddings(verbose=False)
+    # TODO(azalia) clean up embedding/group_embedding_size names
+    self.group_emb_size = (
+        2 * self.num_groups + len(self.type_dict) +
+        self.hparams.max_num_outputs * self.hparams.max_output_size)
+    self.embedding_size = self.group_emb_size
+    self.initializer = init_ops.glorot_uniform_initializer(
+        seed=self.hparams.seed)
+
+    with variable_scope.variable_scope(
+        self.hparams.name,
+        initializer=self.initializer,
+        reuse=variable_scope.AUTO_REUSE):
+      # define parameters of feedforward
+      variable_scope.get_variable("w_grouping_ff", [
+          1 + self.hparams.max_num_outputs * self.hparams.max_output_size +
+          self.hparams.adj_embed_dim, self.hparams.grouping_hidden_size
+      ])
+      variable_scope.get_variable(
+          "w_grouping_softmax",
+          [self.hparams.grouping_hidden_size, self.num_groups])
+      if self.hparams.bi_lstm:
+        variable_scope.get_variable("encoder_lstm_forward", [
+            self.embedding_size + self.hparams.hidden_size / 2,
+            2 * self.hparams.hidden_size
+        ])
+        variable_scope.get_variable("encoder_lstm_backward", [
+            self.embedding_size + self.hparams.hidden_size / 2,
+            2 * self.hparams.hidden_size
+        ])
+        variable_scope.get_variable(
+            "device_embeddings", [self.num_devices, self.hparams.hidden_size])
+        variable_scope.get_variable(
+            "decoder_lstm",
+            [2 * self.hparams.hidden_size, 4 * self.hparams.hidden_size])
+        variable_scope.get_variable(
+            "device_softmax", [2 * self.hparams.hidden_size, self.num_devices])
+        variable_scope.get_variable("device_go_embedding",
+                                    [1, self.hparams.hidden_size])
+        variable_scope.get_variable(
+            "encoder_forget_bias",
+            shape=1,
+            dtype=dtypes.float32,
+            initializer=init_ops.constant_initializer(
+                self.hparams.forget_bias_init))
+        variable_scope.get_variable(
+            "decoder_forget_bias",
+            shape=1,
+            dtype=dtypes.float32,
+            initializer=init_ops.constant_initializer(
+                self.hparams.forget_bias_init))
+        variable_scope.get_variable(
+            "attn_w_1", [self.hparams.hidden_size, self.hparams.hidden_size])
+        variable_scope.get_variable(
+            "attn_w_2", [self.hparams.hidden_size, self.hparams.hidden_size])
+        variable_scope.get_variable("attn_v", [self.hparams.hidden_size, 1])
+
+      else:
+        variable_scope.get_variable("encoder_lstm", [
+            self.embedding_size + self.hparams.hidden_size,
+            4 * self.hparams.hidden_size
+        ])
+        variable_scope.get_variable(
+            "device_embeddings", [self.num_devices, self.hparams.hidden_size])
+        variable_scope.get_variable(
+            "decoder_lstm",
+            [2 * self.hparams.hidden_size, 4 * self.hparams.hidden_size])
+        variable_scope.get_variable(
+            "device_softmax", [2 * self.hparams.hidden_size, self.num_devices])
+        variable_scope.get_variable("device_go_embedding",
+                                    [1, self.hparams.hidden_size])
+        variable_scope.get_variable(
+            "encoder_forget_bias",
+            shape=1,
+            dtype=dtypes.float32,
+            initializer=init_ops.constant_initializer(
+                self.hparams.forget_bias_init))
+        variable_scope.get_variable(
+            "decoder_forget_bias",
+            shape=1,
+            dtype=dtypes.float32,
+            initializer=init_ops.constant_initializer(
+                self.hparams.forget_bias_init))
+        variable_scope.get_variable(
+            "attn_w_1", [self.hparams.hidden_size, self.hparams.hidden_size])
+        variable_scope.get_variable(
+            "attn_w_2", [self.hparams.hidden_size, self.hparams.hidden_size])
+        variable_scope.get_variable("attn_v", [self.hparams.hidden_size, 1])
+    seq2seq_input_layer = array_ops.placeholder_with_default(
+        array_ops.zeros([1, self.num_groups, self.group_emb_size],
+                        dtypes.float32),
+        shape=(1, self.num_groups, self.group_emb_size))
+    self.seq2seq_input_layer = seq2seq_input_layer
+
+  def compute_reward(self, run_time):
+    if self.hparams.reward_function == "id":
+      reward = run_time
+    elif self.hparams.reward_function == "sqrt":
+      reward = math.sqrt(run_time)
+    elif self.hparams.reward_function == "log":
+      reward = math.log1p(run_time)
+    else:
+      raise NotImplementedError(
+          "Unrecognized reward function '%s', consider your "
+          "--reward_function flag value." % self.hparams.reward_function)
+    return reward
+
+  def build_controller(self):
+    """RL optimization interface.
+
+    Returns:
+      ops: A dictionary holding handles of the model used for training.
+    """
+
+    self._global_step = training_util.get_or_create_global_step()
+    ops = {}
+    ops["loss"] = 0
+
+    failing_signal = self.compute_reward(self.hparams.failing_signal)
+
+    ctr = {}
+
+    with tf_ops.name_scope("controller_{}".format(self.ctrl_id)):
+      with variable_scope.variable_scope("controller_{}".format(self.ctrl_id)):
+        ctr["reward"] = {"value": [], "ph": [], "update": []}
+        ctr["ready"] = {"value": [], "ph": [], "update": []}
+        ctr["best_reward"] = {"value": [], "update": []}
+        for i in range(self.hparams.num_children):
+          reward_value = variable_scope.get_local_variable(
+              "reward_{}".format(i),
+              initializer=0.0,
+              dtype=dtypes.float32,
+              trainable=False)
+          reward_ph = array_ops.placeholder(
+              dtypes.float32, shape=(), name="reward_ph_{}".format(i))
+          reward_update = state_ops.assign(
+              reward_value, reward_ph, use_locking=True)
+          ctr["reward"]["value"].append(reward_value)
+          ctr["reward"]["ph"].append(reward_ph)
+          ctr["reward"]["update"].append(reward_update)
+          best_reward = variable_scope.get_local_variable(
+              "best_reward_{}".format(i),
+              initializer=failing_signal,
+              dtype=dtypes.float32,
+              trainable=False)
+          ctr["best_reward"]["value"].append(best_reward)
+          ctr["best_reward"]["update"].append(
+              state_ops.assign(best_reward,
+                               math_ops.minimum(best_reward, reward_update)))
+
+          ready_value = variable_scope.get_local_variable(
+              "ready_{}".format(i),
+              initializer=True,
+              dtype=dtypes.bool,
+              trainable=False)
+          ready_ph = array_ops.placeholder(
+              dtypes.bool, shape=(), name="ready_ph_{}".format(i))
+          ready_update = state_ops.assign(
+              ready_value, ready_ph, use_locking=True)
+          ctr["ready"]["value"].append(ready_value)
+          ctr["ready"]["ph"].append(ready_ph)
+          ctr["ready"]["update"].append(ready_update)
+
+      ctr["grouping_y_preds"], ctr["grouping_log_probs"] = self.get_groupings()
+      summary.histogram(
+          "grouping_actions",
+          array_ops.slice(ctr["grouping_y_preds"]["sample"], [0, 0],
+                          [1, array_ops.shape(self.op_embeddings)[0]]))
+
+      with variable_scope.variable_scope("controller_{}".format(self.ctrl_id)):
+        ctr["baseline"] = variable_scope.get_local_variable(
+            "baseline",
+            initializer=failing_signal
+            if self.hparams.start_with_failing_signal else 0.0,
+            dtype=dtypes.float32,
+            trainable=False)
+
+      new_baseline = self.hparams.bl_dec * ctr["baseline"] + (
+          1 - self.hparams.bl_dec) * math_ops.reduce_mean(
+              ctr["reward"]["value"])
+      if not self.hparams.always_update_baseline:
+        baseline_mask = math_ops.less(ctr["reward"]["value"], failing_signal)
+        selected_reward = array_ops.boolean_mask(ctr["reward"]["value"],
+                                                 baseline_mask)
+        selected_baseline = control_flow_ops.cond(
+            math_ops.reduce_any(baseline_mask),
+            lambda: math_ops.reduce_mean(selected_reward),
+            lambda: constant_op.constant(0, dtype=dtypes.float32))
+        ctr["pos_reward"] = selected_baseline
+        pos_ = math_ops.less(
+            constant_op.constant(0, dtype=dtypes.float32), selected_baseline)
+        selected_baseline = self.hparams.bl_dec * ctr["baseline"] + (
+            1 - self.hparams.bl_dec) * selected_baseline
+        selected_baseline = control_flow_ops.cond(
+            pos_, lambda: selected_baseline, lambda: ctr["baseline"])
+        new_baseline = control_flow_ops.cond(
+            math_ops.less(self.global_step,
+                          self.hparams.stop_updating_after_steps),
+            lambda: new_baseline, lambda: selected_baseline)
+      ctr["baseline_update"] = state_ops.assign(
+          ctr["baseline"], new_baseline, use_locking=True)
+
+      ctr["y_preds"], ctr["log_probs"] = self.get_placements()
+      summary.histogram("actions", ctr["y_preds"]["sample"])
+      mask = math_ops.less(ctr["reward"]["value"], failing_signal)
+      ctr["loss"] = ctr["reward"]["value"] - ctr["baseline"]
+      ctr["loss"] *= (
+          ctr["log_probs"]["sample"] + ctr["grouping_log_probs"]["sample"])
+
+      selected_loss = array_ops.boolean_mask(ctr["loss"], mask)
+      selected_loss = control_flow_ops.cond(
+          math_ops.reduce_any(mask),
+          lambda: math_ops.reduce_mean(-selected_loss),
+          lambda: constant_op.constant(0, dtype=dtypes.float32))
+
+      ctr["loss"] = control_flow_ops.cond(
+          math_ops.less(self.global_step,
+                        self.hparams.stop_updating_after_steps),
+          lambda: math_ops.reduce_mean(-ctr["loss"]), lambda: selected_loss)
+
+      ctr["reward_s"] = math_ops.reduce_mean(ctr["reward"]["value"])
+      summary.scalar("loss", ctr["loss"])
+      summary.scalar("avg_reward", ctr["reward_s"])
+      summary.scalar("best_reward_so_far", best_reward)
+      summary.scalar(
+          "advantage",
+          math_ops.reduce_mean(ctr["reward"]["value"] - ctr["baseline"]))
+
+    with variable_scope.variable_scope(
+        "optimizer", reuse=variable_scope.AUTO_REUSE):
+      (ctr["train_op"], ctr["lr"], ctr["grad_norm"],
+       ctr["grad_norms"]) = self._get_train_ops(
+           ctr["loss"],
+           tf_ops.get_collection(tf_ops.GraphKeys.TRAINABLE_VARIABLES),
+           self.global_step,
+           grad_bound=self.hparams.grad_bound,
+           lr_init=self.hparams.lr,
+           lr_dec=self.hparams.lr_dec,
+           start_decay_step=self.hparams.start_decay_step,
+           decay_steps=self.hparams.decay_steps,
+           optimizer_type=self.hparams.optimizer_type)
+
+    summary.scalar("gradnorm", ctr["grad_norm"])
+    summary.scalar("lr", ctr["lr"])
+    ctr["summary"] = summary.merge_all()
+    ops["controller"] = ctr
+
+    self.ops = ops
+    return ops
+
+  @property
+  def global_step(self):
+    return self._global_step
+
+  def create_op_embeddings(self, verbose=False):
+    if verbose:
+      print("process input graph for op embeddings")
+    self.num_ops = len(self.important_ops)
+    # topological sort of important nodes
+    topo_order = [op.name for op in self.important_ops]
+
+    # create index to name for topologicaly sorted important nodes
+    name_to_topo_order_index = {}
+    for idx, x in enumerate(topo_order):
+      name_to_topo_order_index[x] = idx
+    self.name_to_topo_order_index = name_to_topo_order_index
+
+    # create adj matrix
+    adj_dict = {}
+    for idx, op in enumerate(self.important_ops):
+      for output_op in self.get_node_fanout(op):
+        output_op_name = output_op.name
+        if output_op_name in self.important_op_names:
+          if name_to_topo_order_index[op.name] not in adj_dict:
+            adj_dict[name_to_topo_order_index[op.name]] = []
+          adj_dict[name_to_topo_order_index[op.name]].extend(
+              [name_to_topo_order_index[output_op_name], 1])
+          if output_op_name not in adj_dict:
+            adj_dict[name_to_topo_order_index[output_op_name]] = []
+          adj_dict[name_to_topo_order_index[output_op_name]].extend(
+              [name_to_topo_order_index[op.name], -1])
+
+    # get op_type op_output_shape, and adj info
+    output_embed_dim = (self.hparams.max_num_outputs *
+                        self.hparams.max_output_size)
+
+    # TODO(bsteiner): don't filter based on used ops so that we can generalize
+    # to models that use other types of ops.
+    used_ops = set()
+    for node in self.important_ops:
+      op_type = str(node.op)
+      used_ops.add(op_type)
+
+    self.type_dict = {}
+    for op_type in self.cluster.ListAvailableOps():
+      if op_type in used_ops:
+        self.type_dict[op_type] = len(self.type_dict)
+
+    op_types = np.zeros([self.num_ops], dtype=np.int32)
+    op_output_shapes = np.full(
+        [self.num_ops, output_embed_dim], -1.0, dtype=np.float32)
+    for idx, node in enumerate(self.important_ops):
+      op_types[idx] = self.type_dict[node.op]
+      # output shape
+      op_name = node.name
+      for i, output_prop in enumerate(self.node_properties[op_name]):
+        if output_prop.shape.__str__() == "<unknown>":
+          continue
+        shape = output_prop.shape
+        for j, dim in enumerate(shape.dim):
+          if dim.size >= 0:
+            if i * self.hparams.max_output_size + j >= output_embed_dim:
+              break
+            op_output_shapes[idx,
+                             i * self.hparams.max_output_size + j] = dim.size
+    # adj for padding
+    op_adj = np.full(
+        [self.num_ops, self.hparams.adj_embed_dim], 0, dtype=np.float32)
+    for idx in adj_dict:
+      neighbors = adj_dict[int(idx)]
+      min_dim = min(self.hparams.adj_embed_dim, len(neighbors))
+      padding_size = self.hparams.adj_embed_dim - min_dim
+      neighbors = neighbors[:min_dim] + [0] * padding_size
+      op_adj[int(idx)] = neighbors
+
+    # op_embedding   starts here
+    op_embeddings = np.zeros(
+        [
+            self.num_ops,
+            1 + self.hparams.max_num_outputs * self.hparams.max_output_size +
+            self.hparams.adj_embed_dim
+        ],
+        dtype=np.float32)
+    for idx, op_name in enumerate(topo_order):
+      op_embeddings[idx] = np.concatenate(
+          (np.array([op_types[idx]]), op_output_shapes[idx], op_adj[int(idx)]))
+    self.op_embeddings = constant_op.constant(
+        op_embeddings, dtype=dtypes.float32)
+    if verbose:
+      print("num_ops = {}".format(self.num_ops))
+      print("num_types = {}".format(len(self.type_dict)))
+
+  def get_groupings(self, *args, **kwargs):
+    num_children = self.hparams.num_children
+    with variable_scope.variable_scope("controller_{}".format(self.ctrl_id)):
+      grouping_actions_cache = variable_scope.get_local_variable(
+          "grouping_actions_cache",
+          initializer=init_ops.zeros_initializer,
+          dtype=dtypes.int32,
+          shape=[num_children, self.num_ops],
+          trainable=False)
+    input_layer = self.op_embeddings
+    input_layer = array_ops.expand_dims(input_layer, 0)
+    feed_ff_input_layer = array_ops.tile(input_layer, [num_children, 1, 1])
+    grouping_actions, grouping_log_probs = {}, {}
+    grouping_actions["sample"], grouping_log_probs[
+        "sample"] = self.make_grouping_predictions(feed_ff_input_layer)
+
+    grouping_actions["sample"] = state_ops.assign(grouping_actions_cache,
+                                                  grouping_actions["sample"])
+    self.grouping_actions_cache = grouping_actions_cache
+
+    return grouping_actions, grouping_log_probs
+
+  def make_grouping_predictions(self, input_layer, reuse=None):
+    """model that predicts grouping (grouping_actions).
+
+    Args:
+      input_layer: group_input_layer
+      reuse: reuse
+
+    Returns:
+       grouping_actions: actions
+       grouping_log_probs: log probabilities corresponding to actions
+    """
+    with variable_scope.variable_scope(self.hparams.name, reuse=True):
+      # input_layer: tensor of size [1, num_ops, hidden_size]
+      w_grouping_ff = variable_scope.get_variable("w_grouping_ff")
+      w_grouping_softmax = variable_scope.get_variable("w_grouping_softmax")
+
+    batch_size = array_ops.shape(input_layer)[0]
+    embedding_dim = array_ops.shape(input_layer)[2]
+
+    reshaped = array_ops.reshape(input_layer,
+                                 [batch_size * self.num_ops, embedding_dim])
+    ff_output = math_ops.matmul(reshaped, w_grouping_ff)
+    logits = math_ops.matmul(ff_output, w_grouping_softmax)
+    if self.hparams.logits_std_noise > 0:
+      num_in_logits = math_ops.cast(
+          array_ops.size(logits), dtype=dtypes.float32)
+      avg_norm = math_ops.divide(
+          linalg_ops.norm(logits), math_ops.sqrt(num_in_logits))
+      logits_noise = random_ops.random_normal(
+          array_ops.shape(logits),
+          stddev=self.hparams.logits_std_noise * avg_norm)
+      logits = control_flow_ops.cond(
+          self.global_step > self.hparams.stop_noise_step, lambda: logits,
+          lambda: logits + logits_noise)
+    logits = array_ops.reshape(logits,
+                               [batch_size * self.num_ops, self.num_groups])
+    actions = random_ops.multinomial(logits, 1, seed=self.hparams.seed)
+    actions = math_ops.to_int32(actions)
+    actions = array_ops.reshape(actions, [batch_size, self.num_ops])
+    action_label = array_ops.reshape(actions, [-1])
+    log_probs = nn_ops.sparse_softmax_cross_entropy_with_logits(
+        logits=logits, labels=action_label)
+    log_probs = array_ops.reshape(log_probs, [batch_size, -1])
+    log_probs = math_ops.reduce_sum(log_probs, 1)
+    grouping_actions = actions
+    grouping_log_probs = log_probs
+    return grouping_actions, grouping_log_probs
+
+  def create_group_embeddings(self, grouping_actions, verbose=False):
+    """Approximating the blocks of a TF graph from a graph_def.
+
+    Args:
+      grouping_actions: grouping predictions
+      verbose: print stuffs.
+
+    Returns:
+      groups: list of groups.
+    """
+    if verbose:
+      print("Processing input_graph")
+
+    # TODO(azalia): Build inter-adjacencies dag matrix.
+    # record dag_matrix
+    dag_matrix = np.zeros([self.num_groups, self.num_groups], dtype=np.float32)
+    for op in self.important_ops:
+      topo_op_index = self.name_to_topo_order_index[op.name]
+      # TODO(agoldie) child_id
+      group_index = grouping_actions[0][topo_op_index]
+      for output_op in self.get_node_fanout(op):
+        if output_op.name not in self.important_op_names:
+          continue
+        output_group_index = grouping_actions[0][self.name_to_topo_order_index[
+            output_op.name]]
+        dag_matrix[group_index, output_group_index] += 1.0
+    num_connections = np.sum(dag_matrix)
+    num_intra_group_connections = dag_matrix.trace()
+    num_inter_group_connections = num_connections - num_intra_group_connections
+    if verbose:
+      print("grouping evaluation metric")
+      print("num_connections={} num_intra_group_connections={} "
+            "num_inter_group_connections={}").format(
+                num_connections, num_intra_group_connections,
+                num_inter_group_connections)
+    self.dag_matrix = dag_matrix
+
+    # output_shape
+    op_output_shapes = np.zeros(
+        [
+            len(self.important_ops),
+            self.hparams.max_num_outputs * self.hparams.max_output_size
+        ],
+        dtype=np.float32)
+
+    for idx, op in enumerate(self.important_ops):
+      for i, output_properties in enumerate(self.node_properties[op.name]):
+        if output_properties.shape.__str__() == "<unknown>":
+          continue
+        if i > self.hparams.max_num_outputs:
+          break
+        shape = output_properties.shape
+        for j, dim in enumerate(shape.dim):
+          if dim.size > 0:
+            k = i * self.hparams.max_output_size + j
+            if k >= self.hparams.max_num_outputs * self.hparams.max_output_size:
+              break
+            op_output_shapes[idx, k] = dim.size
+
+    # group_embedding
+    group_embedding = np.zeros(
+        [
+            self.num_groups, len(self.type_dict) +
+            self.hparams.max_num_outputs * self.hparams.max_output_size
+        ],
+        dtype=np.float32)
+    for op_index, op in enumerate(self.important_ops):
+      group_index = grouping_actions[0][self.name_to_topo_order_index[op.name]]
+      type_name = str(op.op)
+      type_index = self.type_dict[type_name]
+      group_embedding[group_index, type_index] += 1
+      group_embedding[group_index, :self.hparams.max_num_outputs * self.hparams.
+                      max_output_size] += (
+                          op_output_shapes[op_index])
+    grouping_adjacencies = np.concatenate(
+        [dag_matrix, np.transpose(dag_matrix)], axis=1)
+    group_embedding = np.concatenate(
+        [grouping_adjacencies, group_embedding], axis=1)
+    group_normalizer = np.amax(group_embedding, axis=1, keepdims=True)
+    group_embedding /= (group_normalizer + 1.0)
+    if verbose:
+      print("Finished Processing Input Graph")
+    return group_embedding
+
+  def get_placements(self, *args, **kwargs):
+    num_children = self.hparams.num_children
+    with variable_scope.variable_scope("controller_{}".format(self.ctrl_id)):
+      actions_cache = variable_scope.get_local_variable(
+          "actions_cache",
+          initializer=init_ops.zeros_initializer,
+          dtype=dtypes.int32,
+          shape=[num_children, self.num_groups],
+          trainable=False)
+
+    x = array_ops.tile(self.seq2seq_input_layer, [num_children, 1, 1])
+    last_c, last_h, attn_mem = self.encode(x)
+    actions, log_probs = {}, {}
+    actions["sample"], log_probs["sample"] = (
+        self.decode(
+            x, last_c, last_h, attn_mem, mode="sample"))
+    actions["target"], log_probs["target"] = (
+        self.decode(
+            x,
+            last_c,
+            last_h,
+            attn_mem,
+            mode="target",
+            y=actions_cache))
+    actions["greedy"], log_probs["greedy"] = (
+        self.decode(
+            x, last_c, last_h, attn_mem, mode="greedy"))
+    actions["sample"] = control_flow_ops.cond(
+        self.global_step < self.hparams.stop_sampling,
+        lambda: state_ops.assign(actions_cache, actions["sample"]),
+        lambda: state_ops.assign(actions_cache, actions["target"]))
+    self.actions_cache = actions_cache
+
+    return actions, log_probs
+
+  def encode(self, x):
+    """Encoder using LSTM.
+
+    Args:
+      x: tensor of size [num_children, num_groups, embedding_size]
+
+    Returns:
+      last_c, last_h: tensors of size [num_children, hidden_size], the final
+        LSTM states
+      attn_mem: tensor of size [num_children, num_groups, hidden_size], the
+      attention
+        memory, i.e. concatenation of all hidden states, linearly transformed by
+        an attention matrix attn_w_1
+    """
+    if self.hparams.bi_lstm:
+      with variable_scope.variable_scope(self.hparams.name, reuse=True):
+        w_lstm_forward = variable_scope.get_variable("encoder_lstm_forward")
+        w_lstm_backward = variable_scope.get_variable("encoder_lstm_backward")
+        forget_bias = variable_scope.get_variable("encoder_forget_bias")
+        attn_w_1 = variable_scope.get_variable("attn_w_1")
+    else:
+      with variable_scope.variable_scope(self.hparams.name, reuse=True):
+        w_lstm = variable_scope.get_variable("encoder_lstm")
+        forget_bias = variable_scope.get_variable("encoder_forget_bias")
+        attn_w_1 = variable_scope.get_variable("attn_w_1")
+
+    embedding_size = array_ops.shape(x)[2]
+
+    signals = array_ops.split(x, self.num_groups, axis=1)
+    for i in range(len(signals)):
+      signals[i] = array_ops.reshape(
+          signals[i], [self.hparams.num_children, embedding_size])
+
+    if self.hparams.bi_lstm:
+
+      def body(i, prev_c_forward, prev_h_forward, prev_c_backward,
+               prev_h_backward):
+        """while loop for LSTM."""
+        signal_forward = signals[i]
+        next_c_forward, next_h_forward = lstm(signal_forward, prev_c_forward,
+                                              prev_h_forward, w_lstm_forward,
+                                              forget_bias)
+
+        signal_backward = signals[self.num_groups - 1 - i]
+        next_c_backward, next_h_backward = lstm(
+            signal_backward, prev_c_backward, prev_h_backward, w_lstm_backward,
+            forget_bias)
+
+        next_h = array_ops.concat([next_h_forward, next_h_backward], axis=1)
+        all_h.append(next_h)
+
+        return (next_c_forward, next_h_forward, next_c_backward,
+                next_h_backward)
+
+      c_forward = array_ops.zeros(
+          [self.hparams.num_children, self.hparams.hidden_size / 2],
+          dtype=dtypes.float32)
+      h_forward = array_ops.zeros(
+          [self.hparams.num_children, self.hparams.hidden_size / 2],
+          dtype=dtypes.float32)
+
+      c_backward = array_ops.zeros(
+          [self.hparams.num_children, self.hparams.hidden_size / 2],
+          dtype=dtypes.float32)
+      h_backward = array_ops.zeros(
+          [self.hparams.num_children, self.hparams.hidden_size / 2],
+          dtype=dtypes.float32)
+      all_h = []
+
+      for i in range(0, self.num_groups):
+        c_forward, h_forward, c_backward, h_backward = body(
+            i, c_forward, h_forward, c_backward, h_backward)
+
+      last_c = array_ops.concat([c_forward, c_backward], axis=1)
+      last_h = array_ops.concat([h_forward, h_backward], axis=1)
+      attn_mem = array_ops.stack(all_h)
+
+    else:
+
+      def body(i, prev_c, prev_h):
+        signal = signals[i]
+        next_c, next_h = lstm(signal, prev_c, prev_h, w_lstm, forget_bias)
+        all_h.append(next_h)
+        return next_c, next_h
+
+      c = array_ops.zeros(
+          [self.hparams.num_children, self.hparams.hidden_size],
+          dtype=dtypes.float32)
+      h = array_ops.zeros(
+          [self.hparams.num_children, self.hparams.hidden_size],
+          dtype=dtypes.float32)
+      all_h = []
+
+      for i in range(0, self.num_groups):
+        c, h = body(i, c, h)
+
+      last_c = c
+      last_h = h
+      attn_mem = array_ops.stack(all_h)
+
+    attn_mem = array_ops.transpose(attn_mem, [1, 0, 2])
+    attn_mem = array_ops.reshape(
+        attn_mem,
+        [self.hparams.num_children * self.num_groups, self.hparams.hidden_size])
+    attn_mem = math_ops.matmul(attn_mem, attn_w_1)
+    attn_mem = array_ops.reshape(
+        attn_mem,
+        [self.hparams.num_children, self.num_groups, self.hparams.hidden_size])
+
+    return last_c, last_h, attn_mem
+
+  def decode(self,
+             x,
+             last_c,
+             last_h,
+             attn_mem,
+             mode="target",
+             y=None):
+    """Decoder using LSTM.
+
+    Args:
+      x: tensor of size [num_children, num_groups, embedding_size].
+      last_c: tensor of size [num_children, hidden_size], the final LSTM states
+          computed by self.encoder.
+      last_h: same as last_c.
+      attn_mem: tensor of size [num_children, num_groups, hidden_size].
+      mode: "target" or "sample".
+      y: tensor of size [num_children, num_groups], the device placements.
+
+    Returns:
+      actions: tensor of size [num_children, num_groups], the placements of
+          devices
+    """
+    with variable_scope.variable_scope(self.hparams.name, reuse=True):
+      w_lstm = variable_scope.get_variable("decoder_lstm")
+      forget_bias = variable_scope.get_variable("decoder_forget_bias")
+      device_embeddings = variable_scope.get_variable("device_embeddings")
+      device_softmax = variable_scope.get_variable("device_softmax")
+      device_go_embedding = variable_scope.get_variable("device_go_embedding")
+      attn_w_2 = variable_scope.get_variable("attn_w_2")
+      attn_v = variable_scope.get_variable("attn_v")
+
+    actions = tensor_array_ops.TensorArray(
+        dtypes.int32,
+        size=self.num_groups,
+        infer_shape=False,
+        clear_after_read=False)
+
+    # pylint: disable=unused-argument
+    def condition(i, *args):
+      return math_ops.less(i, self.num_groups)
+
+    # pylint: disable=missing-docstring
+    def body(i, prev_c, prev_h, actions, log_probs):
+      # pylint: disable=g-long-lambda
+      signal = control_flow_ops.cond(
+          math_ops.equal(i, 0),
+          lambda: array_ops.tile(device_go_embedding,
+                                 [self.hparams.num_children, 1]),
+          lambda: embedding_ops.embedding_lookup(device_embeddings,
+                                                 actions.read(i - 1))
+      )
+      if self.hparams.keep_prob is not None:
+        signal = nn_ops.dropout(signal, self.hparams.keep_prob)
+      next_c, next_h = lstm(signal, prev_c, prev_h, w_lstm, forget_bias)
+      query = math_ops.matmul(next_h, attn_w_2)
+      query = array_ops.reshape(
+          query, [self.hparams.num_children, 1, self.hparams.hidden_size])
+      query = math_ops.tanh(query + attn_mem)
+      query = array_ops.reshape(query, [
+          self.hparams.num_children * self.num_groups, self.hparams.hidden_size
+      ])
+      query = math_ops.matmul(query, attn_v)
+      query = array_ops.reshape(query,
+                                [self.hparams.num_children, self.num_groups])
+      query = nn_ops.softmax(query)
+      query = array_ops.reshape(query,
+                                [self.hparams.num_children, self.num_groups, 1])
+      query = math_ops.reduce_sum(attn_mem * query, axis=1)
+      query = array_ops.concat([next_h, query], axis=1)
+      logits = math_ops.matmul(query, device_softmax)
+      logits /= self.hparams.temperature
+      if self.hparams.tanh_constant > 0:
+        logits = math_ops.tanh(logits) * self.hparams.tanh_constant
+      if self.hparams.logits_std_noise > 0:
+        num_in_logits = math_ops.cast(
+            array_ops.size(logits), dtype=dtypes.float32)
+        avg_norm = math_ops.divide(
+            linalg_ops.norm(logits), math_ops.sqrt(num_in_logits))
+        logits_noise = random_ops.random_normal(
+            array_ops.shape(logits),
+            stddev=self.hparams.logits_std_noise * avg_norm)
+        logits = control_flow_ops.cond(
+            self.global_step > self.hparams.stop_noise_step, lambda: logits,
+            lambda: logits + logits_noise)
+
+      if mode == "sample":
+        next_y = random_ops.multinomial(logits, 1, seed=self.hparams.seed)
+      elif mode == "greedy":
+        next_y = math_ops.argmax(logits, 1)
+      elif mode == "target":
+        next_y = array_ops.slice(y, [0, i], [-1, 1])
+      else:
+        raise NotImplementedError
+      next_y = math_ops.to_int32(next_y)
+      next_y = array_ops.reshape(next_y, [self.hparams.num_children])
+      actions = actions.write(i, next_y)
+      log_probs += nn_ops.sparse_softmax_cross_entropy_with_logits(
+          logits=logits, labels=next_y)
+      return i + 1, next_c, next_h, actions, log_probs
+
+    loop_vars = [
+        constant_op.constant(0, dtype=dtypes.int32), last_c, last_h, actions,
+        array_ops.zeros([self.hparams.num_children], dtype=dtypes.float32)
+    ]
+    loop_outputs = control_flow_ops.while_loop(condition, body, loop_vars)
+
+    last_c = loop_outputs[-4]
+    last_h = loop_outputs[-3]
+    actions = loop_outputs[-2].stack()
+    actions = array_ops.transpose(actions, [1, 0])
+    log_probs = loop_outputs[-1]
+    return actions, log_probs
+
+  def eval_placement(self,
+                     sess,
+                     child_id=0,
+                     verbose=False):
+    grouping_actions, actions = sess.run([
+        self.grouping_actions_cache,
+        self.actions_cache
+    ])
+    grouping_actions = grouping_actions[child_id]
+    actions = actions[child_id]
+    if verbose:
+      global_step = sess.run(self.global_step)
+      if global_step % 100 == 0:
+        log_string = "op group assignments: "
+        for a in grouping_actions:
+          log_string += "{} ".format(a)
+        print(log_string[:-1])
+        log_string = "group device assignments: "
+        for a in actions:
+          log_string += "{} ".format(a)
+        print(log_string[:-1])
+
+    for op in self.important_ops:
+      topo_order_index = self.name_to_topo_order_index[op.name]
+      group_index = grouping_actions[topo_order_index]
+      op.device = self.devices[actions[group_index]].name
+    try:
+      _, run_time, _ = self.cluster.MeasureCosts(self.item)
+    except errors.ResourceExhaustedError:
+      run_time = self.hparams.failing_signal
+    return run_time
+
+  def update_reward(self,
+                    sess,
+                    run_time,
+                    child_id=0,
+                    verbose=False):
+    reward = self.compute_reward(run_time)
+    controller_ops = self.ops["controller"]
+    _, best_reward = sess.run(
+        [
+            controller_ops["reward"]["update"][child_id],
+            controller_ops["best_reward"]["update"][child_id]
+        ],
+        feed_dict={
+            controller_ops["reward"]["ph"][child_id]: reward,
+        })
+    if verbose:
+      print("run_time={:<.5f} reward={:<.5f} "
+            "best_reward={:<.5f}").format(run_time, reward, best_reward)
+
+    # Reward is a double, best_reward a float: allow for some slack in the
+    # comparison.
+    updated = abs(best_reward - reward) < 1e-6
+    return updated
+
+  def generate_grouping(self, sess):
+    controller_ops = self.ops["controller"]
+    grouping_actions = sess.run(controller_ops["grouping_y_preds"]["sample"])
+    return grouping_actions
+
+  def generate_placement(self, grouping, sess):
+    controller_ops = self.ops["controller"]
+    feed_seq2seq_input_dict = {}
+    feed_seq2seq_input_dict[self.seq2seq_input_layer] = np.expand_dims(
+        grouping, axis=0)
+    sess.run(
+        controller_ops["y_preds"]["sample"], feed_dict=feed_seq2seq_input_dict)
+
+  def process_reward(self, sess):
+    controller_ops = self.ops["controller"]
+    run_ops = [
+        controller_ops["loss"], controller_ops["lr"],
+        controller_ops["grad_norm"], controller_ops["grad_norms"],
+        controller_ops["train_op"]
+    ]
+    sess.run(run_ops)
+    sess.run(controller_ops["baseline_update"])
+
+  def _get_train_ops(self,
+                     loss,
+                     tf_variables,
+                     global_step,
+                     grad_bound=1.25,
+                     lr_init=1e-3,
+                     lr_dec=0.9,
+                     start_decay_step=10000,
+                     decay_steps=100,
+                     optimizer_type="adam"):
+    """Loss optimizer.
+
+    Args:
+      loss: scalar tf tensor
+      tf_variables: list of training variables, typically
+        tf.trainable_variables()
+      global_step: global_step
+      grad_bound: max gradient norm
+      lr_init: initial learning rate
+      lr_dec: leaning rate decay coefficient
+      start_decay_step: start decaying learning rate after this many steps
+      decay_steps: apply decay rate factor at this step intervals
+      optimizer_type: optimizer type should be either adam or sgd
+
+    Returns:
+      train_op: training op
+      learning_rate: scalar learning rate tensor
+      grad_norm: l2 norm of the gradient vector
+      all_grad_norms: l2 norm of each component
+    """
+    lr_gstep = global_step - start_decay_step
+
+    def f1():
+      return constant_op.constant(lr_init)
+
+    def f2():
+      return learning_rate_decay.exponential_decay(lr_init, lr_gstep,
+                                                   decay_steps, lr_dec, True)
+
+    learning_rate = control_flow_ops.cond(
+        math_ops.less(global_step, start_decay_step),
+        f1,
+        f2,
+        name="learning_rate")
+
+    if optimizer_type == "adam":
+      opt = adam.AdamOptimizer(learning_rate)
+    elif optimizer_type == "sgd":
+      opt = gradient_descent.GradientDescentOptimizer(learning_rate)
+    grads_and_vars = opt.compute_gradients(loss, tf_variables)
+    grad_norm = clip_ops.global_norm([g for g, v in grads_and_vars])
+    all_grad_norms = {}
+    clipped_grads = []
+    clipped_rate = math_ops.maximum(grad_norm / grad_bound, 1.0)
+    for g, v in grads_and_vars:
+      if g is not None:
+        if isinstance(g, tf_ops.IndexedSlices):
+          clipped = g.values / clipped_rate
+          norm_square = math_ops.reduce_sum(clipped * clipped)
+          clipped = tf_ops.IndexedSlices(clipped, g.indices)
+        else:
+          clipped = g / clipped_rate
+          norm_square = math_ops.reduce_sum(clipped * clipped)
+        all_grad_norms[v.name] = math_ops.sqrt(norm_square)
+        clipped_grads.append((clipped, v))
+
+    train_op = opt.apply_gradients(clipped_grads, global_step)
+    return train_op, learning_rate, grad_norm, all_grad_norms
+
+
+def lstm(x, prev_c, prev_h, w_lstm, forget_bias):
+  """LSTM cell.
+
+  Args:
+    x: tensors of size [num_children, hidden_size].
+    prev_c: tensors of size [num_children, hidden_size].
+    prev_h: same as prev_c.
+    w_lstm: .
+    forget_bias: .
+
+  Returns:
+    next_c:
+    next_h:
+  """
+  ifog = math_ops.matmul(array_ops.concat([x, prev_h], axis=1), w_lstm)
+  i, f, o, g = array_ops.split(ifog, 4, axis=1)
+  i = math_ops.sigmoid(i)
+  f = math_ops.sigmoid(f + forget_bias)
+  o = math_ops.sigmoid(o)
+  g = math_ops.tanh(g)
+  next_c = i * g + f * prev_c
+  next_h = o * math_ops.tanh(next_c)
+  return next_c, next_h
index d0fc1a0..9a84c60 100644 (file)
@@ -96,10 +96,10 @@ static GItem TF_NewItem(
   return GItem(item.release());
 }
 
-static std::vector<string> TF_IdentifyImportantOps(GItem item, bool sort_topologically,
+static PyObject* TF_IdentifyImportantOps(GItem item, bool sort_topologically,
                                                    TF_Status* status) {
   if (item.is_none()) {
-    return {};
+    Py_RETURN_NONE;
   }
 
   std::vector<const tensorflow::NodeDef*> main_ops = item->MainOpsFanin();
@@ -132,7 +132,13 @@ static std::vector<string> TF_IdentifyImportantOps(GItem item, bool sort_topolog
     }
   }
 
-  return ops;
+  PyGILState_STATE gstate = PyGILState_Ensure();
+  PyObject* result = PyList_New(ops.size());
+  for (int i = 0; i < ops.size(); ++i) {
+    PyList_SetItem(result, i, PyString_FromString(ops[i].c_str()));
+  }
+  PyGILState_Release(gstate);
+  return result;
 }
 
 static PyObject* TF_GetOpProperties(GItem item) {
@@ -305,7 +311,7 @@ static PyObject* TF_GetColocationGroups(GItem item) {
 static GItem TF_NewItem(
     const tensorflow::MetaGraphDef& meta_graph, bool ignore_colocation,
     bool ignore_user_placement, TF_Status* out_status);
-static std::vector<string> TF_IdentifyImportantOps(GItem item, bool sort_topologically,
-                                                   TF_Status* status);
+static PyObject* TF_IdentifyImportantOps(GItem item, bool sort_topologically,
+                                         TF_Status* status);
 static PyObject* TF_GetOpProperties(GItem item);
 static PyObject* TF_GetColocationGroups(GItem item);
index cd70e2f..7c3efd6 100644 (file)
@@ -56,7 +56,7 @@ class ItemTest(test.TestCase):
       mg = meta_graph.create_meta_graph_def(graph=g)
       grappler_item = item.Item(mg)
       op_list = grappler_item.IdentifyImportantOps()
-      self.assertItemsEqual([b'Const', b'Const_1', b'add'], op_list)
+      self.assertItemsEqual(['Const', 'Const_1', 'add'], op_list)
 
   def testOpProperties(self):
     with ops.Graph().as_default() as g: