--- /dev/null
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""HierarchicalController Class.
+
+The HierarchicalController encompasses the entire lifecycle of training the
+device placement policy, including generating op embeddings, getting groups for
+each op, placing those groups and running the predicted placements.
+
+Different assignment models can inherit from this class.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import numpy as np
+import six
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops as tf_ops
+from tensorflow.python.grappler.controller import Controller
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.summary import summary
+from tensorflow.python.training import adam
+from tensorflow.python.training import gradient_descent
+from tensorflow.python.training import learning_rate_decay
+from tensorflow.python.training import training_util
+
+
+class PlacerParams(object):
+ """Class to hold a set of placement parameters as name-value pairs.
+
+ A typical usage is as follows:
+
+ ```python
+ # Create a PlacerParams object specifying names and values of the model
+ # parameters:
+ params = PlacerParams(hidden_size=128, decay_steps=50)
+
+ # The parameters are available as attributes of the PlacerParams object:
+ hparams.hidden_size ==> 128
+ hparams.decay_steps ==> 50
+ ```
+
+ """
+
+ def __init__(self, **kwargs):
+ """Create an instance of `PlacerParams` from keyword arguments.
+
+ The keyword arguments specify name-values pairs for the parameters.
+ The parameter types are inferred from the type of the values passed.
+
+ The parameter names are added as attributes of `PlacerParams` object,
+ and they can be accessed directly with the dot notation `params._name_`.
+
+ Example:
+
+ ```python
+ # Define 1 parameter: 'hidden_size'
+ params = PlacerParams(hidden_size=128)
+ params.hidden_size ==> 128
+ ```
+
+ Args:
+ **kwargs: Key-value pairs where the key is the parameter name and
+ the value is the value for the parameter.
+ """
+ for name, value in six.iteritems(kwargs):
+ self.add_param(name, value)
+
+ def add_param(self, name, value):
+ """Adds {name, value} pair to hyperparameters.
+
+ Args:
+ name: Name of the hyperparameter.
+ value: Value of the hyperparameter. Can be one of the following types:
+ int, float, string, int list, float list, or string list.
+
+ Raises:
+ ValueError: if one of the arguments is invalid.
+ """
+ # Keys in kwargs are unique, but 'name' could be the name of a pre-existing
+ # attribute of this object. In that case we refuse to use it as a
+ # parameter name.
+ if getattr(self, name, None) is not None:
+ raise ValueError("Parameter name is reserved: %s" % name)
+ setattr(self, name, value)
+
+
+def hierarchical_controller_hparams():
+ """Hyperparameters for hierarchical planner."""
+ return PlacerParams(
+ hidden_size=512,
+ forget_bias_init=1.0,
+ temperature=1.0,
+ logits_std_noise=0.5,
+ stop_noise_step=750,
+ decay_steps=50,
+ max_num_outputs=5,
+ max_output_size=5,
+ tanh_constant=1.0,
+ adj_embed_dim=20,
+ grouping_hidden_size=64,
+ num_groups=None,
+ bi_lstm=True,
+ failing_signal=100,
+ stop_sampling=500,
+ start_with_failing_signal=True,
+ always_update_baseline=False,
+ bl_dec=0.9,
+ grad_bound=1.0,
+ lr=0.1,
+ lr_dec=0.95,
+ start_decay_step=400,
+ optimizer_type="adam",
+ stop_updating_after_steps=1000,
+ name="hierarchical_controller",
+ keep_prob=1.0,
+ reward_function="sqrt",
+ seed=1234,
+ # distributed training params
+ num_children=1)
+
+
+class HierarchicalController(Controller):
+ """HierarchicalController class."""
+
+ def __init__(self, hparams, item, cluster, controller_id=0):
+ """HierarchicalController class initializer.
+
+ Args:
+ hparams: All hyper-parameters.
+ item: The metagraph to place.
+ cluster: The cluster of hardware devices to optimize for.
+ controller_id: the id of the controller in a multi-controller setup.
+ """
+ super(HierarchicalController, self).__init__(item, cluster)
+ self.ctrl_id = controller_id
+ self.hparams = hparams
+
+ if self.hparams.num_groups is None:
+ self.num_groups = min(256, 20 * self.num_devices)
+ else:
+ self.num_groups = self.hparams.num_groups
+
+ # creates self.op_embeddings and self.type_dict
+ self.create_op_embeddings(verbose=False)
+ # TODO(azalia) clean up embedding/group_embedding_size names
+ self.group_emb_size = (
+ 2 * self.num_groups + len(self.type_dict) +
+ self.hparams.max_num_outputs * self.hparams.max_output_size)
+ self.embedding_size = self.group_emb_size
+ self.initializer = init_ops.glorot_uniform_initializer(
+ seed=self.hparams.seed)
+
+ with variable_scope.variable_scope(
+ self.hparams.name,
+ initializer=self.initializer,
+ reuse=variable_scope.AUTO_REUSE):
+ # define parameters of feedforward
+ variable_scope.get_variable("w_grouping_ff", [
+ 1 + self.hparams.max_num_outputs * self.hparams.max_output_size +
+ self.hparams.adj_embed_dim, self.hparams.grouping_hidden_size
+ ])
+ variable_scope.get_variable(
+ "w_grouping_softmax",
+ [self.hparams.grouping_hidden_size, self.num_groups])
+ if self.hparams.bi_lstm:
+ variable_scope.get_variable("encoder_lstm_forward", [
+ self.embedding_size + self.hparams.hidden_size / 2,
+ 2 * self.hparams.hidden_size
+ ])
+ variable_scope.get_variable("encoder_lstm_backward", [
+ self.embedding_size + self.hparams.hidden_size / 2,
+ 2 * self.hparams.hidden_size
+ ])
+ variable_scope.get_variable(
+ "device_embeddings", [self.num_devices, self.hparams.hidden_size])
+ variable_scope.get_variable(
+ "decoder_lstm",
+ [2 * self.hparams.hidden_size, 4 * self.hparams.hidden_size])
+ variable_scope.get_variable(
+ "device_softmax", [2 * self.hparams.hidden_size, self.num_devices])
+ variable_scope.get_variable("device_go_embedding",
+ [1, self.hparams.hidden_size])
+ variable_scope.get_variable(
+ "encoder_forget_bias",
+ shape=1,
+ dtype=dtypes.float32,
+ initializer=init_ops.constant_initializer(
+ self.hparams.forget_bias_init))
+ variable_scope.get_variable(
+ "decoder_forget_bias",
+ shape=1,
+ dtype=dtypes.float32,
+ initializer=init_ops.constant_initializer(
+ self.hparams.forget_bias_init))
+ variable_scope.get_variable(
+ "attn_w_1", [self.hparams.hidden_size, self.hparams.hidden_size])
+ variable_scope.get_variable(
+ "attn_w_2", [self.hparams.hidden_size, self.hparams.hidden_size])
+ variable_scope.get_variable("attn_v", [self.hparams.hidden_size, 1])
+
+ else:
+ variable_scope.get_variable("encoder_lstm", [
+ self.embedding_size + self.hparams.hidden_size,
+ 4 * self.hparams.hidden_size
+ ])
+ variable_scope.get_variable(
+ "device_embeddings", [self.num_devices, self.hparams.hidden_size])
+ variable_scope.get_variable(
+ "decoder_lstm",
+ [2 * self.hparams.hidden_size, 4 * self.hparams.hidden_size])
+ variable_scope.get_variable(
+ "device_softmax", [2 * self.hparams.hidden_size, self.num_devices])
+ variable_scope.get_variable("device_go_embedding",
+ [1, self.hparams.hidden_size])
+ variable_scope.get_variable(
+ "encoder_forget_bias",
+ shape=1,
+ dtype=dtypes.float32,
+ initializer=init_ops.constant_initializer(
+ self.hparams.forget_bias_init))
+ variable_scope.get_variable(
+ "decoder_forget_bias",
+ shape=1,
+ dtype=dtypes.float32,
+ initializer=init_ops.constant_initializer(
+ self.hparams.forget_bias_init))
+ variable_scope.get_variable(
+ "attn_w_1", [self.hparams.hidden_size, self.hparams.hidden_size])
+ variable_scope.get_variable(
+ "attn_w_2", [self.hparams.hidden_size, self.hparams.hidden_size])
+ variable_scope.get_variable("attn_v", [self.hparams.hidden_size, 1])
+ seq2seq_input_layer = array_ops.placeholder_with_default(
+ array_ops.zeros([1, self.num_groups, self.group_emb_size],
+ dtypes.float32),
+ shape=(1, self.num_groups, self.group_emb_size))
+ self.seq2seq_input_layer = seq2seq_input_layer
+
+ def compute_reward(self, run_time):
+ if self.hparams.reward_function == "id":
+ reward = run_time
+ elif self.hparams.reward_function == "sqrt":
+ reward = math.sqrt(run_time)
+ elif self.hparams.reward_function == "log":
+ reward = math.log1p(run_time)
+ else:
+ raise NotImplementedError(
+ "Unrecognized reward function '%s', consider your "
+ "--reward_function flag value." % self.hparams.reward_function)
+ return reward
+
+ def build_controller(self):
+ """RL optimization interface.
+
+ Returns:
+ ops: A dictionary holding handles of the model used for training.
+ """
+
+ self._global_step = training_util.get_or_create_global_step()
+ ops = {}
+ ops["loss"] = 0
+
+ failing_signal = self.compute_reward(self.hparams.failing_signal)
+
+ ctr = {}
+
+ with tf_ops.name_scope("controller_{}".format(self.ctrl_id)):
+ with variable_scope.variable_scope("controller_{}".format(self.ctrl_id)):
+ ctr["reward"] = {"value": [], "ph": [], "update": []}
+ ctr["ready"] = {"value": [], "ph": [], "update": []}
+ ctr["best_reward"] = {"value": [], "update": []}
+ for i in range(self.hparams.num_children):
+ reward_value = variable_scope.get_local_variable(
+ "reward_{}".format(i),
+ initializer=0.0,
+ dtype=dtypes.float32,
+ trainable=False)
+ reward_ph = array_ops.placeholder(
+ dtypes.float32, shape=(), name="reward_ph_{}".format(i))
+ reward_update = state_ops.assign(
+ reward_value, reward_ph, use_locking=True)
+ ctr["reward"]["value"].append(reward_value)
+ ctr["reward"]["ph"].append(reward_ph)
+ ctr["reward"]["update"].append(reward_update)
+ best_reward = variable_scope.get_local_variable(
+ "best_reward_{}".format(i),
+ initializer=failing_signal,
+ dtype=dtypes.float32,
+ trainable=False)
+ ctr["best_reward"]["value"].append(best_reward)
+ ctr["best_reward"]["update"].append(
+ state_ops.assign(best_reward,
+ math_ops.minimum(best_reward, reward_update)))
+
+ ready_value = variable_scope.get_local_variable(
+ "ready_{}".format(i),
+ initializer=True,
+ dtype=dtypes.bool,
+ trainable=False)
+ ready_ph = array_ops.placeholder(
+ dtypes.bool, shape=(), name="ready_ph_{}".format(i))
+ ready_update = state_ops.assign(
+ ready_value, ready_ph, use_locking=True)
+ ctr["ready"]["value"].append(ready_value)
+ ctr["ready"]["ph"].append(ready_ph)
+ ctr["ready"]["update"].append(ready_update)
+
+ ctr["grouping_y_preds"], ctr["grouping_log_probs"] = self.get_groupings()
+ summary.histogram(
+ "grouping_actions",
+ array_ops.slice(ctr["grouping_y_preds"]["sample"], [0, 0],
+ [1, array_ops.shape(self.op_embeddings)[0]]))
+
+ with variable_scope.variable_scope("controller_{}".format(self.ctrl_id)):
+ ctr["baseline"] = variable_scope.get_local_variable(
+ "baseline",
+ initializer=failing_signal
+ if self.hparams.start_with_failing_signal else 0.0,
+ dtype=dtypes.float32,
+ trainable=False)
+
+ new_baseline = self.hparams.bl_dec * ctr["baseline"] + (
+ 1 - self.hparams.bl_dec) * math_ops.reduce_mean(
+ ctr["reward"]["value"])
+ if not self.hparams.always_update_baseline:
+ baseline_mask = math_ops.less(ctr["reward"]["value"], failing_signal)
+ selected_reward = array_ops.boolean_mask(ctr["reward"]["value"],
+ baseline_mask)
+ selected_baseline = control_flow_ops.cond(
+ math_ops.reduce_any(baseline_mask),
+ lambda: math_ops.reduce_mean(selected_reward),
+ lambda: constant_op.constant(0, dtype=dtypes.float32))
+ ctr["pos_reward"] = selected_baseline
+ pos_ = math_ops.less(
+ constant_op.constant(0, dtype=dtypes.float32), selected_baseline)
+ selected_baseline = self.hparams.bl_dec * ctr["baseline"] + (
+ 1 - self.hparams.bl_dec) * selected_baseline
+ selected_baseline = control_flow_ops.cond(
+ pos_, lambda: selected_baseline, lambda: ctr["baseline"])
+ new_baseline = control_flow_ops.cond(
+ math_ops.less(self.global_step,
+ self.hparams.stop_updating_after_steps),
+ lambda: new_baseline, lambda: selected_baseline)
+ ctr["baseline_update"] = state_ops.assign(
+ ctr["baseline"], new_baseline, use_locking=True)
+
+ ctr["y_preds"], ctr["log_probs"] = self.get_placements()
+ summary.histogram("actions", ctr["y_preds"]["sample"])
+ mask = math_ops.less(ctr["reward"]["value"], failing_signal)
+ ctr["loss"] = ctr["reward"]["value"] - ctr["baseline"]
+ ctr["loss"] *= (
+ ctr["log_probs"]["sample"] + ctr["grouping_log_probs"]["sample"])
+
+ selected_loss = array_ops.boolean_mask(ctr["loss"], mask)
+ selected_loss = control_flow_ops.cond(
+ math_ops.reduce_any(mask),
+ lambda: math_ops.reduce_mean(-selected_loss),
+ lambda: constant_op.constant(0, dtype=dtypes.float32))
+
+ ctr["loss"] = control_flow_ops.cond(
+ math_ops.less(self.global_step,
+ self.hparams.stop_updating_after_steps),
+ lambda: math_ops.reduce_mean(-ctr["loss"]), lambda: selected_loss)
+
+ ctr["reward_s"] = math_ops.reduce_mean(ctr["reward"]["value"])
+ summary.scalar("loss", ctr["loss"])
+ summary.scalar("avg_reward", ctr["reward_s"])
+ summary.scalar("best_reward_so_far", best_reward)
+ summary.scalar(
+ "advantage",
+ math_ops.reduce_mean(ctr["reward"]["value"] - ctr["baseline"]))
+
+ with variable_scope.variable_scope(
+ "optimizer", reuse=variable_scope.AUTO_REUSE):
+ (ctr["train_op"], ctr["lr"], ctr["grad_norm"],
+ ctr["grad_norms"]) = self._get_train_ops(
+ ctr["loss"],
+ tf_ops.get_collection(tf_ops.GraphKeys.TRAINABLE_VARIABLES),
+ self.global_step,
+ grad_bound=self.hparams.grad_bound,
+ lr_init=self.hparams.lr,
+ lr_dec=self.hparams.lr_dec,
+ start_decay_step=self.hparams.start_decay_step,
+ decay_steps=self.hparams.decay_steps,
+ optimizer_type=self.hparams.optimizer_type)
+
+ summary.scalar("gradnorm", ctr["grad_norm"])
+ summary.scalar("lr", ctr["lr"])
+ ctr["summary"] = summary.merge_all()
+ ops["controller"] = ctr
+
+ self.ops = ops
+ return ops
+
+ @property
+ def global_step(self):
+ return self._global_step
+
+ def create_op_embeddings(self, verbose=False):
+ if verbose:
+ print("process input graph for op embeddings")
+ self.num_ops = len(self.important_ops)
+ # topological sort of important nodes
+ topo_order = [op.name for op in self.important_ops]
+
+ # create index to name for topologicaly sorted important nodes
+ name_to_topo_order_index = {}
+ for idx, x in enumerate(topo_order):
+ name_to_topo_order_index[x] = idx
+ self.name_to_topo_order_index = name_to_topo_order_index
+
+ # create adj matrix
+ adj_dict = {}
+ for idx, op in enumerate(self.important_ops):
+ for output_op in self.get_node_fanout(op):
+ output_op_name = output_op.name
+ if output_op_name in self.important_op_names:
+ if name_to_topo_order_index[op.name] not in adj_dict:
+ adj_dict[name_to_topo_order_index[op.name]] = []
+ adj_dict[name_to_topo_order_index[op.name]].extend(
+ [name_to_topo_order_index[output_op_name], 1])
+ if output_op_name not in adj_dict:
+ adj_dict[name_to_topo_order_index[output_op_name]] = []
+ adj_dict[name_to_topo_order_index[output_op_name]].extend(
+ [name_to_topo_order_index[op.name], -1])
+
+ # get op_type op_output_shape, and adj info
+ output_embed_dim = (self.hparams.max_num_outputs *
+ self.hparams.max_output_size)
+
+ # TODO(bsteiner): don't filter based on used ops so that we can generalize
+ # to models that use other types of ops.
+ used_ops = set()
+ for node in self.important_ops:
+ op_type = str(node.op)
+ used_ops.add(op_type)
+
+ self.type_dict = {}
+ for op_type in self.cluster.ListAvailableOps():
+ if op_type in used_ops:
+ self.type_dict[op_type] = len(self.type_dict)
+
+ op_types = np.zeros([self.num_ops], dtype=np.int32)
+ op_output_shapes = np.full(
+ [self.num_ops, output_embed_dim], -1.0, dtype=np.float32)
+ for idx, node in enumerate(self.important_ops):
+ op_types[idx] = self.type_dict[node.op]
+ # output shape
+ op_name = node.name
+ for i, output_prop in enumerate(self.node_properties[op_name]):
+ if output_prop.shape.__str__() == "<unknown>":
+ continue
+ shape = output_prop.shape
+ for j, dim in enumerate(shape.dim):
+ if dim.size >= 0:
+ if i * self.hparams.max_output_size + j >= output_embed_dim:
+ break
+ op_output_shapes[idx,
+ i * self.hparams.max_output_size + j] = dim.size
+ # adj for padding
+ op_adj = np.full(
+ [self.num_ops, self.hparams.adj_embed_dim], 0, dtype=np.float32)
+ for idx in adj_dict:
+ neighbors = adj_dict[int(idx)]
+ min_dim = min(self.hparams.adj_embed_dim, len(neighbors))
+ padding_size = self.hparams.adj_embed_dim - min_dim
+ neighbors = neighbors[:min_dim] + [0] * padding_size
+ op_adj[int(idx)] = neighbors
+
+ # op_embedding starts here
+ op_embeddings = np.zeros(
+ [
+ self.num_ops,
+ 1 + self.hparams.max_num_outputs * self.hparams.max_output_size +
+ self.hparams.adj_embed_dim
+ ],
+ dtype=np.float32)
+ for idx, op_name in enumerate(topo_order):
+ op_embeddings[idx] = np.concatenate(
+ (np.array([op_types[idx]]), op_output_shapes[idx], op_adj[int(idx)]))
+ self.op_embeddings = constant_op.constant(
+ op_embeddings, dtype=dtypes.float32)
+ if verbose:
+ print("num_ops = {}".format(self.num_ops))
+ print("num_types = {}".format(len(self.type_dict)))
+
+ def get_groupings(self, *args, **kwargs):
+ num_children = self.hparams.num_children
+ with variable_scope.variable_scope("controller_{}".format(self.ctrl_id)):
+ grouping_actions_cache = variable_scope.get_local_variable(
+ "grouping_actions_cache",
+ initializer=init_ops.zeros_initializer,
+ dtype=dtypes.int32,
+ shape=[num_children, self.num_ops],
+ trainable=False)
+ input_layer = self.op_embeddings
+ input_layer = array_ops.expand_dims(input_layer, 0)
+ feed_ff_input_layer = array_ops.tile(input_layer, [num_children, 1, 1])
+ grouping_actions, grouping_log_probs = {}, {}
+ grouping_actions["sample"], grouping_log_probs[
+ "sample"] = self.make_grouping_predictions(feed_ff_input_layer)
+
+ grouping_actions["sample"] = state_ops.assign(grouping_actions_cache,
+ grouping_actions["sample"])
+ self.grouping_actions_cache = grouping_actions_cache
+
+ return grouping_actions, grouping_log_probs
+
+ def make_grouping_predictions(self, input_layer, reuse=None):
+ """model that predicts grouping (grouping_actions).
+
+ Args:
+ input_layer: group_input_layer
+ reuse: reuse
+
+ Returns:
+ grouping_actions: actions
+ grouping_log_probs: log probabilities corresponding to actions
+ """
+ with variable_scope.variable_scope(self.hparams.name, reuse=True):
+ # input_layer: tensor of size [1, num_ops, hidden_size]
+ w_grouping_ff = variable_scope.get_variable("w_grouping_ff")
+ w_grouping_softmax = variable_scope.get_variable("w_grouping_softmax")
+
+ batch_size = array_ops.shape(input_layer)[0]
+ embedding_dim = array_ops.shape(input_layer)[2]
+
+ reshaped = array_ops.reshape(input_layer,
+ [batch_size * self.num_ops, embedding_dim])
+ ff_output = math_ops.matmul(reshaped, w_grouping_ff)
+ logits = math_ops.matmul(ff_output, w_grouping_softmax)
+ if self.hparams.logits_std_noise > 0:
+ num_in_logits = math_ops.cast(
+ array_ops.size(logits), dtype=dtypes.float32)
+ avg_norm = math_ops.divide(
+ linalg_ops.norm(logits), math_ops.sqrt(num_in_logits))
+ logits_noise = random_ops.random_normal(
+ array_ops.shape(logits),
+ stddev=self.hparams.logits_std_noise * avg_norm)
+ logits = control_flow_ops.cond(
+ self.global_step > self.hparams.stop_noise_step, lambda: logits,
+ lambda: logits + logits_noise)
+ logits = array_ops.reshape(logits,
+ [batch_size * self.num_ops, self.num_groups])
+ actions = random_ops.multinomial(logits, 1, seed=self.hparams.seed)
+ actions = math_ops.to_int32(actions)
+ actions = array_ops.reshape(actions, [batch_size, self.num_ops])
+ action_label = array_ops.reshape(actions, [-1])
+ log_probs = nn_ops.sparse_softmax_cross_entropy_with_logits(
+ logits=logits, labels=action_label)
+ log_probs = array_ops.reshape(log_probs, [batch_size, -1])
+ log_probs = math_ops.reduce_sum(log_probs, 1)
+ grouping_actions = actions
+ grouping_log_probs = log_probs
+ return grouping_actions, grouping_log_probs
+
+ def create_group_embeddings(self, grouping_actions, verbose=False):
+ """Approximating the blocks of a TF graph from a graph_def.
+
+ Args:
+ grouping_actions: grouping predictions
+ verbose: print stuffs.
+
+ Returns:
+ groups: list of groups.
+ """
+ if verbose:
+ print("Processing input_graph")
+
+ # TODO(azalia): Build inter-adjacencies dag matrix.
+ # record dag_matrix
+ dag_matrix = np.zeros([self.num_groups, self.num_groups], dtype=np.float32)
+ for op in self.important_ops:
+ topo_op_index = self.name_to_topo_order_index[op.name]
+ # TODO(agoldie) child_id
+ group_index = grouping_actions[0][topo_op_index]
+ for output_op in self.get_node_fanout(op):
+ if output_op.name not in self.important_op_names:
+ continue
+ output_group_index = grouping_actions[0][self.name_to_topo_order_index[
+ output_op.name]]
+ dag_matrix[group_index, output_group_index] += 1.0
+ num_connections = np.sum(dag_matrix)
+ num_intra_group_connections = dag_matrix.trace()
+ num_inter_group_connections = num_connections - num_intra_group_connections
+ if verbose:
+ print("grouping evaluation metric")
+ print("num_connections={} num_intra_group_connections={} "
+ "num_inter_group_connections={}").format(
+ num_connections, num_intra_group_connections,
+ num_inter_group_connections)
+ self.dag_matrix = dag_matrix
+
+ # output_shape
+ op_output_shapes = np.zeros(
+ [
+ len(self.important_ops),
+ self.hparams.max_num_outputs * self.hparams.max_output_size
+ ],
+ dtype=np.float32)
+
+ for idx, op in enumerate(self.important_ops):
+ for i, output_properties in enumerate(self.node_properties[op.name]):
+ if output_properties.shape.__str__() == "<unknown>":
+ continue
+ if i > self.hparams.max_num_outputs:
+ break
+ shape = output_properties.shape
+ for j, dim in enumerate(shape.dim):
+ if dim.size > 0:
+ k = i * self.hparams.max_output_size + j
+ if k >= self.hparams.max_num_outputs * self.hparams.max_output_size:
+ break
+ op_output_shapes[idx, k] = dim.size
+
+ # group_embedding
+ group_embedding = np.zeros(
+ [
+ self.num_groups, len(self.type_dict) +
+ self.hparams.max_num_outputs * self.hparams.max_output_size
+ ],
+ dtype=np.float32)
+ for op_index, op in enumerate(self.important_ops):
+ group_index = grouping_actions[0][self.name_to_topo_order_index[op.name]]
+ type_name = str(op.op)
+ type_index = self.type_dict[type_name]
+ group_embedding[group_index, type_index] += 1
+ group_embedding[group_index, :self.hparams.max_num_outputs * self.hparams.
+ max_output_size] += (
+ op_output_shapes[op_index])
+ grouping_adjacencies = np.concatenate(
+ [dag_matrix, np.transpose(dag_matrix)], axis=1)
+ group_embedding = np.concatenate(
+ [grouping_adjacencies, group_embedding], axis=1)
+ group_normalizer = np.amax(group_embedding, axis=1, keepdims=True)
+ group_embedding /= (group_normalizer + 1.0)
+ if verbose:
+ print("Finished Processing Input Graph")
+ return group_embedding
+
+ def get_placements(self, *args, **kwargs):
+ num_children = self.hparams.num_children
+ with variable_scope.variable_scope("controller_{}".format(self.ctrl_id)):
+ actions_cache = variable_scope.get_local_variable(
+ "actions_cache",
+ initializer=init_ops.zeros_initializer,
+ dtype=dtypes.int32,
+ shape=[num_children, self.num_groups],
+ trainable=False)
+
+ x = array_ops.tile(self.seq2seq_input_layer, [num_children, 1, 1])
+ last_c, last_h, attn_mem = self.encode(x)
+ actions, log_probs = {}, {}
+ actions["sample"], log_probs["sample"] = (
+ self.decode(
+ x, last_c, last_h, attn_mem, mode="sample"))
+ actions["target"], log_probs["target"] = (
+ self.decode(
+ x,
+ last_c,
+ last_h,
+ attn_mem,
+ mode="target",
+ y=actions_cache))
+ actions["greedy"], log_probs["greedy"] = (
+ self.decode(
+ x, last_c, last_h, attn_mem, mode="greedy"))
+ actions["sample"] = control_flow_ops.cond(
+ self.global_step < self.hparams.stop_sampling,
+ lambda: state_ops.assign(actions_cache, actions["sample"]),
+ lambda: state_ops.assign(actions_cache, actions["target"]))
+ self.actions_cache = actions_cache
+
+ return actions, log_probs
+
+ def encode(self, x):
+ """Encoder using LSTM.
+
+ Args:
+ x: tensor of size [num_children, num_groups, embedding_size]
+
+ Returns:
+ last_c, last_h: tensors of size [num_children, hidden_size], the final
+ LSTM states
+ attn_mem: tensor of size [num_children, num_groups, hidden_size], the
+ attention
+ memory, i.e. concatenation of all hidden states, linearly transformed by
+ an attention matrix attn_w_1
+ """
+ if self.hparams.bi_lstm:
+ with variable_scope.variable_scope(self.hparams.name, reuse=True):
+ w_lstm_forward = variable_scope.get_variable("encoder_lstm_forward")
+ w_lstm_backward = variable_scope.get_variable("encoder_lstm_backward")
+ forget_bias = variable_scope.get_variable("encoder_forget_bias")
+ attn_w_1 = variable_scope.get_variable("attn_w_1")
+ else:
+ with variable_scope.variable_scope(self.hparams.name, reuse=True):
+ w_lstm = variable_scope.get_variable("encoder_lstm")
+ forget_bias = variable_scope.get_variable("encoder_forget_bias")
+ attn_w_1 = variable_scope.get_variable("attn_w_1")
+
+ embedding_size = array_ops.shape(x)[2]
+
+ signals = array_ops.split(x, self.num_groups, axis=1)
+ for i in range(len(signals)):
+ signals[i] = array_ops.reshape(
+ signals[i], [self.hparams.num_children, embedding_size])
+
+ if self.hparams.bi_lstm:
+
+ def body(i, prev_c_forward, prev_h_forward, prev_c_backward,
+ prev_h_backward):
+ """while loop for LSTM."""
+ signal_forward = signals[i]
+ next_c_forward, next_h_forward = lstm(signal_forward, prev_c_forward,
+ prev_h_forward, w_lstm_forward,
+ forget_bias)
+
+ signal_backward = signals[self.num_groups - 1 - i]
+ next_c_backward, next_h_backward = lstm(
+ signal_backward, prev_c_backward, prev_h_backward, w_lstm_backward,
+ forget_bias)
+
+ next_h = array_ops.concat([next_h_forward, next_h_backward], axis=1)
+ all_h.append(next_h)
+
+ return (next_c_forward, next_h_forward, next_c_backward,
+ next_h_backward)
+
+ c_forward = array_ops.zeros(
+ [self.hparams.num_children, self.hparams.hidden_size / 2],
+ dtype=dtypes.float32)
+ h_forward = array_ops.zeros(
+ [self.hparams.num_children, self.hparams.hidden_size / 2],
+ dtype=dtypes.float32)
+
+ c_backward = array_ops.zeros(
+ [self.hparams.num_children, self.hparams.hidden_size / 2],
+ dtype=dtypes.float32)
+ h_backward = array_ops.zeros(
+ [self.hparams.num_children, self.hparams.hidden_size / 2],
+ dtype=dtypes.float32)
+ all_h = []
+
+ for i in range(0, self.num_groups):
+ c_forward, h_forward, c_backward, h_backward = body(
+ i, c_forward, h_forward, c_backward, h_backward)
+
+ last_c = array_ops.concat([c_forward, c_backward], axis=1)
+ last_h = array_ops.concat([h_forward, h_backward], axis=1)
+ attn_mem = array_ops.stack(all_h)
+
+ else:
+
+ def body(i, prev_c, prev_h):
+ signal = signals[i]
+ next_c, next_h = lstm(signal, prev_c, prev_h, w_lstm, forget_bias)
+ all_h.append(next_h)
+ return next_c, next_h
+
+ c = array_ops.zeros(
+ [self.hparams.num_children, self.hparams.hidden_size],
+ dtype=dtypes.float32)
+ h = array_ops.zeros(
+ [self.hparams.num_children, self.hparams.hidden_size],
+ dtype=dtypes.float32)
+ all_h = []
+
+ for i in range(0, self.num_groups):
+ c, h = body(i, c, h)
+
+ last_c = c
+ last_h = h
+ attn_mem = array_ops.stack(all_h)
+
+ attn_mem = array_ops.transpose(attn_mem, [1, 0, 2])
+ attn_mem = array_ops.reshape(
+ attn_mem,
+ [self.hparams.num_children * self.num_groups, self.hparams.hidden_size])
+ attn_mem = math_ops.matmul(attn_mem, attn_w_1)
+ attn_mem = array_ops.reshape(
+ attn_mem,
+ [self.hparams.num_children, self.num_groups, self.hparams.hidden_size])
+
+ return last_c, last_h, attn_mem
+
+ def decode(self,
+ x,
+ last_c,
+ last_h,
+ attn_mem,
+ mode="target",
+ y=None):
+ """Decoder using LSTM.
+
+ Args:
+ x: tensor of size [num_children, num_groups, embedding_size].
+ last_c: tensor of size [num_children, hidden_size], the final LSTM states
+ computed by self.encoder.
+ last_h: same as last_c.
+ attn_mem: tensor of size [num_children, num_groups, hidden_size].
+ mode: "target" or "sample".
+ y: tensor of size [num_children, num_groups], the device placements.
+
+ Returns:
+ actions: tensor of size [num_children, num_groups], the placements of
+ devices
+ """
+ with variable_scope.variable_scope(self.hparams.name, reuse=True):
+ w_lstm = variable_scope.get_variable("decoder_lstm")
+ forget_bias = variable_scope.get_variable("decoder_forget_bias")
+ device_embeddings = variable_scope.get_variable("device_embeddings")
+ device_softmax = variable_scope.get_variable("device_softmax")
+ device_go_embedding = variable_scope.get_variable("device_go_embedding")
+ attn_w_2 = variable_scope.get_variable("attn_w_2")
+ attn_v = variable_scope.get_variable("attn_v")
+
+ actions = tensor_array_ops.TensorArray(
+ dtypes.int32,
+ size=self.num_groups,
+ infer_shape=False,
+ clear_after_read=False)
+
+ # pylint: disable=unused-argument
+ def condition(i, *args):
+ return math_ops.less(i, self.num_groups)
+
+ # pylint: disable=missing-docstring
+ def body(i, prev_c, prev_h, actions, log_probs):
+ # pylint: disable=g-long-lambda
+ signal = control_flow_ops.cond(
+ math_ops.equal(i, 0),
+ lambda: array_ops.tile(device_go_embedding,
+ [self.hparams.num_children, 1]),
+ lambda: embedding_ops.embedding_lookup(device_embeddings,
+ actions.read(i - 1))
+ )
+ if self.hparams.keep_prob is not None:
+ signal = nn_ops.dropout(signal, self.hparams.keep_prob)
+ next_c, next_h = lstm(signal, prev_c, prev_h, w_lstm, forget_bias)
+ query = math_ops.matmul(next_h, attn_w_2)
+ query = array_ops.reshape(
+ query, [self.hparams.num_children, 1, self.hparams.hidden_size])
+ query = math_ops.tanh(query + attn_mem)
+ query = array_ops.reshape(query, [
+ self.hparams.num_children * self.num_groups, self.hparams.hidden_size
+ ])
+ query = math_ops.matmul(query, attn_v)
+ query = array_ops.reshape(query,
+ [self.hparams.num_children, self.num_groups])
+ query = nn_ops.softmax(query)
+ query = array_ops.reshape(query,
+ [self.hparams.num_children, self.num_groups, 1])
+ query = math_ops.reduce_sum(attn_mem * query, axis=1)
+ query = array_ops.concat([next_h, query], axis=1)
+ logits = math_ops.matmul(query, device_softmax)
+ logits /= self.hparams.temperature
+ if self.hparams.tanh_constant > 0:
+ logits = math_ops.tanh(logits) * self.hparams.tanh_constant
+ if self.hparams.logits_std_noise > 0:
+ num_in_logits = math_ops.cast(
+ array_ops.size(logits), dtype=dtypes.float32)
+ avg_norm = math_ops.divide(
+ linalg_ops.norm(logits), math_ops.sqrt(num_in_logits))
+ logits_noise = random_ops.random_normal(
+ array_ops.shape(logits),
+ stddev=self.hparams.logits_std_noise * avg_norm)
+ logits = control_flow_ops.cond(
+ self.global_step > self.hparams.stop_noise_step, lambda: logits,
+ lambda: logits + logits_noise)
+
+ if mode == "sample":
+ next_y = random_ops.multinomial(logits, 1, seed=self.hparams.seed)
+ elif mode == "greedy":
+ next_y = math_ops.argmax(logits, 1)
+ elif mode == "target":
+ next_y = array_ops.slice(y, [0, i], [-1, 1])
+ else:
+ raise NotImplementedError
+ next_y = math_ops.to_int32(next_y)
+ next_y = array_ops.reshape(next_y, [self.hparams.num_children])
+ actions = actions.write(i, next_y)
+ log_probs += nn_ops.sparse_softmax_cross_entropy_with_logits(
+ logits=logits, labels=next_y)
+ return i + 1, next_c, next_h, actions, log_probs
+
+ loop_vars = [
+ constant_op.constant(0, dtype=dtypes.int32), last_c, last_h, actions,
+ array_ops.zeros([self.hparams.num_children], dtype=dtypes.float32)
+ ]
+ loop_outputs = control_flow_ops.while_loop(condition, body, loop_vars)
+
+ last_c = loop_outputs[-4]
+ last_h = loop_outputs[-3]
+ actions = loop_outputs[-2].stack()
+ actions = array_ops.transpose(actions, [1, 0])
+ log_probs = loop_outputs[-1]
+ return actions, log_probs
+
+ def eval_placement(self,
+ sess,
+ child_id=0,
+ verbose=False):
+ grouping_actions, actions = sess.run([
+ self.grouping_actions_cache,
+ self.actions_cache
+ ])
+ grouping_actions = grouping_actions[child_id]
+ actions = actions[child_id]
+ if verbose:
+ global_step = sess.run(self.global_step)
+ if global_step % 100 == 0:
+ log_string = "op group assignments: "
+ for a in grouping_actions:
+ log_string += "{} ".format(a)
+ print(log_string[:-1])
+ log_string = "group device assignments: "
+ for a in actions:
+ log_string += "{} ".format(a)
+ print(log_string[:-1])
+
+ for op in self.important_ops:
+ topo_order_index = self.name_to_topo_order_index[op.name]
+ group_index = grouping_actions[topo_order_index]
+ op.device = self.devices[actions[group_index]].name
+ try:
+ _, run_time, _ = self.cluster.MeasureCosts(self.item)
+ except errors.ResourceExhaustedError:
+ run_time = self.hparams.failing_signal
+ return run_time
+
+ def update_reward(self,
+ sess,
+ run_time,
+ child_id=0,
+ verbose=False):
+ reward = self.compute_reward(run_time)
+ controller_ops = self.ops["controller"]
+ _, best_reward = sess.run(
+ [
+ controller_ops["reward"]["update"][child_id],
+ controller_ops["best_reward"]["update"][child_id]
+ ],
+ feed_dict={
+ controller_ops["reward"]["ph"][child_id]: reward,
+ })
+ if verbose:
+ print("run_time={:<.5f} reward={:<.5f} "
+ "best_reward={:<.5f}").format(run_time, reward, best_reward)
+
+ # Reward is a double, best_reward a float: allow for some slack in the
+ # comparison.
+ updated = abs(best_reward - reward) < 1e-6
+ return updated
+
+ def generate_grouping(self, sess):
+ controller_ops = self.ops["controller"]
+ grouping_actions = sess.run(controller_ops["grouping_y_preds"]["sample"])
+ return grouping_actions
+
+ def generate_placement(self, grouping, sess):
+ controller_ops = self.ops["controller"]
+ feed_seq2seq_input_dict = {}
+ feed_seq2seq_input_dict[self.seq2seq_input_layer] = np.expand_dims(
+ grouping, axis=0)
+ sess.run(
+ controller_ops["y_preds"]["sample"], feed_dict=feed_seq2seq_input_dict)
+
+ def process_reward(self, sess):
+ controller_ops = self.ops["controller"]
+ run_ops = [
+ controller_ops["loss"], controller_ops["lr"],
+ controller_ops["grad_norm"], controller_ops["grad_norms"],
+ controller_ops["train_op"]
+ ]
+ sess.run(run_ops)
+ sess.run(controller_ops["baseline_update"])
+
+ def _get_train_ops(self,
+ loss,
+ tf_variables,
+ global_step,
+ grad_bound=1.25,
+ lr_init=1e-3,
+ lr_dec=0.9,
+ start_decay_step=10000,
+ decay_steps=100,
+ optimizer_type="adam"):
+ """Loss optimizer.
+
+ Args:
+ loss: scalar tf tensor
+ tf_variables: list of training variables, typically
+ tf.trainable_variables()
+ global_step: global_step
+ grad_bound: max gradient norm
+ lr_init: initial learning rate
+ lr_dec: leaning rate decay coefficient
+ start_decay_step: start decaying learning rate after this many steps
+ decay_steps: apply decay rate factor at this step intervals
+ optimizer_type: optimizer type should be either adam or sgd
+
+ Returns:
+ train_op: training op
+ learning_rate: scalar learning rate tensor
+ grad_norm: l2 norm of the gradient vector
+ all_grad_norms: l2 norm of each component
+ """
+ lr_gstep = global_step - start_decay_step
+
+ def f1():
+ return constant_op.constant(lr_init)
+
+ def f2():
+ return learning_rate_decay.exponential_decay(lr_init, lr_gstep,
+ decay_steps, lr_dec, True)
+
+ learning_rate = control_flow_ops.cond(
+ math_ops.less(global_step, start_decay_step),
+ f1,
+ f2,
+ name="learning_rate")
+
+ if optimizer_type == "adam":
+ opt = adam.AdamOptimizer(learning_rate)
+ elif optimizer_type == "sgd":
+ opt = gradient_descent.GradientDescentOptimizer(learning_rate)
+ grads_and_vars = opt.compute_gradients(loss, tf_variables)
+ grad_norm = clip_ops.global_norm([g for g, v in grads_and_vars])
+ all_grad_norms = {}
+ clipped_grads = []
+ clipped_rate = math_ops.maximum(grad_norm / grad_bound, 1.0)
+ for g, v in grads_and_vars:
+ if g is not None:
+ if isinstance(g, tf_ops.IndexedSlices):
+ clipped = g.values / clipped_rate
+ norm_square = math_ops.reduce_sum(clipped * clipped)
+ clipped = tf_ops.IndexedSlices(clipped, g.indices)
+ else:
+ clipped = g / clipped_rate
+ norm_square = math_ops.reduce_sum(clipped * clipped)
+ all_grad_norms[v.name] = math_ops.sqrt(norm_square)
+ clipped_grads.append((clipped, v))
+
+ train_op = opt.apply_gradients(clipped_grads, global_step)
+ return train_op, learning_rate, grad_norm, all_grad_norms
+
+
+def lstm(x, prev_c, prev_h, w_lstm, forget_bias):
+ """LSTM cell.
+
+ Args:
+ x: tensors of size [num_children, hidden_size].
+ prev_c: tensors of size [num_children, hidden_size].
+ prev_h: same as prev_c.
+ w_lstm: .
+ forget_bias: .
+
+ Returns:
+ next_c:
+ next_h:
+ """
+ ifog = math_ops.matmul(array_ops.concat([x, prev_h], axis=1), w_lstm)
+ i, f, o, g = array_ops.split(ifog, 4, axis=1)
+ i = math_ops.sigmoid(i)
+ f = math_ops.sigmoid(f + forget_bias)
+ o = math_ops.sigmoid(o)
+ g = math_ops.tanh(g)
+ next_c = i * g + f * prev_c
+ next_h = o * math_ops.tanh(next_c)
+ return next_c, next_h