Fix type attribute name for Cast.
author    Yao Zhang <yaozhang@google.com>
Tue, 9 Jan 2018 19:31:05 +0000 (11:31 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
Tue, 9 Jan 2018 19:39:24 +0000 (11:39 -0800)
The Cast op stores its types in the "SrcT" and "DstT" attributes rather than "T", so looking up "T" on a Cast node fails. Rather than adding one more per-op special case, the layout optimizer now obtains input and output data types from GraphProperties, which also removes the existing attribute-name special cases for Angle/Complex/ComplexAbs/Imag/Real, Bitcast, Split/SplitV, and the logical/comparison ops.

PiperOrigin-RevId: 181348431

tensorflow/core/grappler/op_types.cc
tensorflow/core/grappler/op_types.h
tensorflow/core/grappler/optimizers/layout_optimizer.cc
tensorflow/python/grappler/layout_optimizer_test.py

diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 63ad6d9221b109f21f69c305e8a87aab2cc77aef..29e8e29ace0ad6b4d23be598fbbcf2fa0a477601 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -66,6 +66,8 @@ bool IsBiasAddGrad(const NodeDef& node) { return node.op() == "BiasAddGrad"; }
 
 bool IsBitcast(const NodeDef& node) { return node.op() == "Bitcast"; }
 
+bool IsCast(const NodeDef& node) { return node.op() == "Cast"; }
+
 bool IsComplex(const NodeDef& node) { return node.op() == "Complex"; }
 
 bool IsComplexAbs(const NodeDef& node) { return node.op() == "ComplexAbs"; }
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 31067391921fff2b0ff61c8478853b3098ddbe2c..6b5e7e3391e956613b8c68cdd54051e788293be7 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -36,6 +36,7 @@ bool IsBetainc(const NodeDef& node);
 bool IsBiasAdd(const NodeDef& node);
 bool IsBiasAddGrad(const NodeDef& node);
 bool IsBitcast(const NodeDef& node);
+bool IsCast(const NodeDef& node);
 bool IsComplex(const NodeDef& node);
 bool IsComplexAbs(const NodeDef& node);
 bool IsConj(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
index 7723ca5defdce9b523426063c64ea7068457fa10..b10f5c0f628b76e6b79f0090223351298c6589ad 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -369,10 +369,12 @@ std::vector<int> DataInputPos(const NodeDef& node) {
 
 class GraphProcessor {
  public:
-  GraphProcessor(const VirtualPlacer& virtual_placer,
+  GraphProcessor(const GraphProperties& graph_properties,
+                 const VirtualPlacer& virtual_placer,
                  const std::unordered_set<string>& nodes_to_preserve,
                  GraphDef* graph, NodeMap* node_map)
-      : virtual_placer_(virtual_placer),
+      : graph_properties_(graph_properties),
+        virtual_placer_(virtual_placer),
         nodes_to_preserve_(nodes_to_preserve),
         graph_(graph),
         node_map_(node_map) {}
@@ -432,6 +434,7 @@ class GraphProcessor {
     return strings::StrCat(base_name, "-", kSuffix);
   }
 
+  const GraphProperties& graph_properties_;
   const VirtualPlacer& virtual_placer_;
   const std::unordered_set<string>& nodes_to_preserve_;
   GraphDef* graph_;
@@ -440,18 +443,21 @@ class GraphProcessor {
 
 struct OptimizeContext {
   OptimizeContext(GraphDef* graph, NodeDef* node, NodeMap* node_map,
+                  const GraphProperties& graph_properties,
                   const VirtualPlacer& virtual_placer,
                   const std::unordered_set<string>& nodes_to_preserve,
                   bool is_in_frame)
       : graph(graph),
         node(node),
         node_map(node_map),
+        graph_properties(graph_properties),
         virtual_placer(virtual_placer),
         nodes_to_preserve(nodes_to_preserve),
         is_in_frame(is_in_frame) {}
   GraphDef* graph;
   NodeDef* node;
   NodeMap* node_map;
+  const GraphProperties& graph_properties;
   const VirtualPlacer& virtual_placer;
   const std::unordered_set<string>& nodes_to_preserve;
   bool is_in_frame;
@@ -460,8 +466,9 @@ struct OptimizeContext {
 class NodeProcessor : public GraphProcessor {
  public:
   explicit NodeProcessor(const OptimizeContext& opt_cxt)
-      : GraphProcessor(opt_cxt.virtual_placer, opt_cxt.nodes_to_preserve,
-                       opt_cxt.graph, opt_cxt.node_map),
+      : GraphProcessor(opt_cxt.graph_properties, opt_cxt.virtual_placer,
+                       opt_cxt.nodes_to_preserve, opt_cxt.graph,
+                       opt_cxt.node_map),
         node_(opt_cxt.node),
         is_in_frame_(opt_cxt.is_in_frame) {}
   virtual ~NodeProcessor() {}
@@ -607,15 +614,15 @@ class NodeProcessor : public GraphProcessor {
     for (const auto& pos : input_pos) {
       string node_name = LayoutOptimizerNode(
           strings::StrCat(node_->name(), "-", pos, "-", kTransposeNHWCToNCHW));
-      TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
+      DataType dtype =
+          graph_properties_.GetInputProperties(node_->name())[pos].dtype();
       auto input_node = node_map_->GetNode(node_->input(pos));
       TF_RETURN_IF_ERROR(HasAttribute(*input_node, "_output_shapes"));
       string const_name = GetOrAddNodePermNHWCToNCHW(pos);
       int output_pos;
       ParseNodeName(node_->input(pos), &output_pos);
       AddNodeTranspose(
-          node_name, node_->input(pos), const_name,
-          node_->attr().at("T").type(),
+          node_name, node_->input(pos), const_name, dtype,
           input_node->attr().at("_output_shapes").list().shape(output_pos),
           true);
       node_map_->UpdateOutput(node_->input(pos), node_->name(), node_name);
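
For context, a minimal sketch of how a caller obtains these inferred dtypes. The GrapplerItem `item` and the node name "cast" are assumptions for illustration, and the exact InferStatically signature varies across TensorFlow versions:

    // Minimal sketch: infer shapes/types once, then look dtypes up by
    // node name instead of reading per-op type attributes such as "T".
    GraphProperties properties(item);
    TF_RETURN_IF_ERROR(properties.InferStatically());
    // For a Cast node this yields the types stored in its "SrcT"/"DstT"
    // attributes without the caller knowing those attribute names.
    DataType in_dtype = properties.GetInputProperties("cast")[0].dtype();
    DataType out_dtype = properties.GetOutputProperties("cast")[0].dtype();
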
@@ -644,23 +651,12 @@ class NodeProcessor : public GraphProcessor {
             string added_node_base_name =
                 strings::StrCat(node_->name(), "-", output_count, "-", i);
             string added_node_name;
+            DataType dtype =
+                graph_properties_.GetOutputProperties(node_->name())[input_port]
+                    .dtype();
             if (op == "Transpose") {
               added_node_name = LayoutOptimizerNode(strings::StrCat(
                   added_node_base_name, "-", kTransposeNCHWToNHWC));
-              DataType dtype;
-              if (IsAngle(*node_) || IsComplex(*node_) ||
-                  IsComplexAbs(*node_) || IsImag(*node_) || IsReal(*node_)) {
-                TF_RETURN_IF_ERROR(HasAttribute(*node_, "Tout"));
-                dtype = node_->attr().at("Tout").type();
-              } else if (IsBitcast(*node_)) {
-                TF_RETURN_IF_ERROR(HasAttribute(*node_, "type"));
-                dtype = node_->attr().at("type").type();
-              } else if (IsLogicalOp(*node_) || IsComparisonOp(*node_)) {
-                dtype = DT_BOOL;
-              } else {
-                TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
-                dtype = node_->attr().at("T").type();
-              }
               TF_RETURN_IF_ERROR(HasAttribute(*node_, "_output_shapes"));
               AddNodeTranspose(
                   added_node_name, input, const_name, dtype,
@@ -669,10 +665,6 @@ class NodeProcessor : public GraphProcessor {
             } else if (op == "DataFormatVecPermute") {
               added_node_name = LayoutOptimizerNode(strings::StrCat(
                   added_node_base_name, "-", kVecPermuteNCHWToNHWC));
-              TF_RETURN_IF_ERROR(HasAttribute(*node_, "out_type"));
-              DataType dtype = (IsSplit(*node_) || IsSplitV(*node_))
-                                   ? DT_INT32
-                                   : node_->attr().at("out_type").type();
               AddNodeDataFormatOp(added_node_name, input, op, dtype, false);
             } else {
               return errors::InvalidArgument("Unsupported op type: ", op);
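
For reference, the removed branches encoded one attribute name per op family, and Cast would have needed yet another case because it stores its types in "SrcT"/"DstT" rather than "T". A sketch of the dispatch this change retires; the IsCast branch is hypothetical, showing exactly the case that was missing:

    // Sketch of the retired per-op attribute lookup; GraphProperties
    // now makes this dispatch unnecessary.
    DataType OutputDataType(const NodeDef& node) {
      if (IsAngle(node) || IsComplex(node) || IsComplexAbs(node) ||
          IsImag(node) || IsReal(node)) {
        return node.attr().at("Tout").type();  // complex-valued math ops
      } else if (IsBitcast(node)) {
        return node.attr().at("type").type();  // Bitcast names it "type"
      } else if (IsCast(node)) {
        return node.attr().at("DstT").type();  // the case that was missing
      } else if (IsLogicalOp(node) || IsComparisonOp(node)) {
        return DT_BOOL;                        // no type attribute at all
      }
      return node.attr().at("T").type();       // the common default
    }
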
@@ -1817,11 +1809,13 @@ class TileProcessor : public AgnosticNodeProcessor {
 class DataLayoutOptimizer : GraphProcessor {
  public:
   explicit DataLayoutOptimizer(
+      const GraphProperties& graph_properties,
       const VirtualPlacer& virtual_placer,
       const LayoutOptimizer::TuningConfig& config,
       const std::unordered_set<string>& nodes_to_preserve, GraphDef* graph,
       NodeMap* node_map)
-      : GraphProcessor(virtual_placer, nodes_to_preserve, graph, node_map),
+      : GraphProcessor(graph_properties, virtual_placer, nodes_to_preserve,
+                       graph, node_map),
         config_(config) {}
 
   Status Optimize() {
@@ -1862,8 +1856,9 @@ class DataLayoutOptimizer : GraphProcessor {
           ops_format_supported.end()) {
         auto node = graph_->mutable_node(i);
         bool is_in_frame = !frames[node].empty();
-        OptimizeContext opt_cxt(graph_, node, node_map_, virtual_placer_,
-                                nodes_to_preserve_, is_in_frame);
+        OptimizeContext opt_cxt(graph_, node, node_map_, graph_properties_,
+                                virtual_placer_, nodes_to_preserve_,
+                                is_in_frame);
         std::unique_ptr<NodeProcessor> node_processor;
         if (IsAvgPoolGrad(*node)) {
           node_processor.reset(new AvgPoolGradProcessor(opt_cxt));
@@ -1911,8 +1906,9 @@ class DataLayoutOptimizer : GraphProcessor {
             ops_format_agnostic.end()) {
           auto node = graph_->mutable_node(i);
           bool is_in_frame = !frames[node].empty();
-          OptimizeContext opt_cxt(graph_, node, node_map_, virtual_placer_,
-                                  nodes_to_preserve_, is_in_frame);
+          OptimizeContext opt_cxt(graph_, node, node_map_, graph_properties_,
+                                  virtual_placer_, nodes_to_preserve_,
+                                  is_in_frame);
           std::unique_ptr<NodeProcessor> node_processor;
           if (IsAddN(*node)) {
             node_processor.reset(new AddNProcessor(opt_cxt));
@@ -2057,8 +2053,9 @@ Status LayoutOptimizer::Tune(const GrapplerItem& item,
     return status;
   }
   NodeMap node_map(output);
-  DataLayoutOptimizer layout_optimizer(*virtual_placer_, config,
-                                       nodes_to_preserve_, output, &node_map);
+  DataLayoutOptimizer layout_optimizer(graph_properties, *virtual_placer_,
+                                       config, nodes_to_preserve_, output,
+                                       &node_map);
   status = layout_optimizer.Optimize();
   return status;
 }
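
Putting the plumbing together, Tune() now reads roughly as follows. This is a sketch using the names in this diff, with error handling abbreviated:

    // Sketch: properties are inferred once per Tune() call and handed
    // down by const reference through DataLayoutOptimizer,
    // OptimizeContext, and every NodeProcessor.
    GraphProperties graph_properties(item);
    TF_RETURN_IF_ERROR(graph_properties.InferStatically());
    NodeMap node_map(output);
    DataLayoutOptimizer layout_optimizer(graph_properties, *virtual_placer_,
                                         config, nodes_to_preserve_, output,
                                         &node_map);
    return layout_optimizer.Optimize();
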
diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py
index 961a4c9f4e1cf9fdcd198ffbb154d3297f0ed65b..1572fb9651e90b98a73d51fd4e9d4dc0ea31ea9f 100644
--- a/tensorflow/python/grappler/layout_optimizer_test.py
+++ b/tensorflow/python/grappler/layout_optimizer_test.py
@@ -405,6 +405,36 @@ class LayoutOptimizerTest(test.TestCase):
       self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
       self.assertAllClose(output_val_ref, output_val, atol=1e-3)
 
+  def testCast(self):
+    if test.is_gpu_available(cuda_only=True):
+      random_seed.set_random_seed(0)
+      x = random_ops.truncated_normal([1, 784], seed=0)
+      conv = _two_layer_model(x)
+      cast = math_ops.cast(conv, dtype='bool')
+      output = array_ops.identity(cast)
+
+      with session.Session() as sess:
+        output_val_ref = sess.run(output)
+
+      with session.Session(config=_get_config()) as sess:
+        metadata = config_pb2.RunMetadata()
+        output_val = sess.run(output, run_metadata=metadata)
+
+      nodes = []
+      num_transposes = 0
+      for node in metadata.cost_graph.node:
+        if _is_transpose(node.name):
+          num_transposes += 1
+        nodes.append(node.name)
+
+      # Four transposes were initially added in the Expand phase of
+      # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
+      expected_num_transposes = 2
+      self.assertEqual(expected_num_transposes, num_transposes)
+      self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
+      self._assert_trans_nchw_to_nhwc('Cast-0-0', nodes)
+      self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+
   def testReduceSumAlongHWC(self):
     if test.is_gpu_available(cuda_only=True):
       random_seed.set_random_seed(0)