class GraphProcessor {
public:
- GraphProcessor(const VirtualPlacer& virtual_placer,
+ GraphProcessor(const GraphProperties& graph_properties,
+ const VirtualPlacer& virtual_placer,
const std::unordered_set<string>& nodes_to_preserve,
GraphDef* graph, NodeMap* node_map)
- : virtual_placer_(virtual_placer),
+ : graph_properties_(graph_properties),
+ virtual_placer_(virtual_placer),
nodes_to_preserve_(nodes_to_preserve),
graph_(graph),
node_map_(node_map) {}
return strings::StrCat(base_name, "-", kSuffix);
}
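+ // Shape and type information inferred for each node in the graph.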
+ const GraphProperties& graph_properties_;
const VirtualPlacer& virtual_placer_;
const std::unordered_set<string>& nodes_to_preserve_;
GraphDef* graph_;
struct OptimizeContext {
OptimizeContext(GraphDef* graph, NodeDef* node, NodeMap* node_map,
+ const GraphProperties& graph_properties,
const VirtualPlacer& virtual_placer,
const std::unordered_set<string>& nodes_to_preserve,
bool is_in_frame)
: graph(graph),
node(node),
node_map(node_map),
+ graph_properties(graph_properties),
virtual_placer(virtual_placer),
nodes_to_preserve(nodes_to_preserve),
is_in_frame(is_in_frame) {}
GraphDef* graph;
NodeDef* node;
NodeMap* node_map;
+ const GraphProperties& graph_properties;
const VirtualPlacer& virtual_placer;
const std::unordered_set<string>& nodes_to_preserve;
bool is_in_frame;
class NodeProcessor : public GraphProcessor {
public:
explicit NodeProcessor(const OptimizeContext& opt_cxt)
- : GraphProcessor(opt_cxt.virtual_placer, opt_cxt.nodes_to_preserve,
- opt_cxt.graph, opt_cxt.node_map),
+ : GraphProcessor(opt_cxt.graph_properties, opt_cxt.virtual_placer,
+ opt_cxt.nodes_to_preserve, opt_cxt.graph,
+ opt_cxt.node_map),
node_(opt_cxt.node),
is_in_frame_(opt_cxt.is_in_frame) {}
virtual ~NodeProcessor() {}
for (const auto& pos : input_pos) {
string node_name = LayoutOptimizerNode(
strings::StrCat(node_->name(), "-", pos, "-", kTransposeNHWCToNCHW));
- TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
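+ // Read the input dtype from the inferred graph properties rather than
+ // the "T" attr, which some ops (e.g. Cast, with SrcT/DstT) do not carry.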
+ DataType dtype =
+ graph_properties_.GetInputProperties(node_->name())[pos].dtype();
auto input_node = node_map_->GetNode(node_->input(pos));
TF_RETURN_IF_ERROR(HasAttribute(*input_node, "_output_shapes"));
string const_name = GetOrAddNodePermNHWCToNCHW(pos);
int output_pos;
ParseNodeName(node_->input(pos), &output_pos);
AddNodeTranspose(
- node_name, node_->input(pos), const_name,
- node_->attr().at("T").type(),
+ node_name, node_->input(pos), const_name, dtype,
input_node->attr().at("_output_shapes").list().shape(output_pos),
true);
node_map_->UpdateOutput(node_->input(pos), node_->name(), node_name);
string added_node_base_name =
strings::StrCat(node_->name(), "-", output_count, "-", i);
string added_node_name;
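+ // The output dtype likewise comes from the inferred properties,
+ // replacing the per-op attr special cases ("Tout", "type", "out_type",
+ // DT_BOOL) removed below.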
+ DataType dtype =
+ graph_properties_.GetOutputProperties(node_->name())[input_port]
+ .dtype();
if (op == "Transpose") {
added_node_name = LayoutOptimizerNode(strings::StrCat(
added_node_base_name, "-", kTransposeNCHWToNHWC));
- DataType dtype;
- if (IsAngle(*node_) || IsComplex(*node_) ||
- IsComplexAbs(*node_) || IsImag(*node_) || IsReal(*node_)) {
- TF_RETURN_IF_ERROR(HasAttribute(*node_, "Tout"));
- dtype = node_->attr().at("Tout").type();
- } else if (IsBitcast(*node_)) {
- TF_RETURN_IF_ERROR(HasAttribute(*node_, "type"));
- dtype = node_->attr().at("type").type();
- } else if (IsLogicalOp(*node_) || IsComparisonOp(*node_)) {
- dtype = DT_BOOL;
- } else {
- TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
- dtype = node_->attr().at("T").type();
- }
TF_RETURN_IF_ERROR(HasAttribute(*node_, "_output_shapes"));
AddNodeTranspose(
added_node_name, input, const_name, dtype,
} else if (op == "DataFormatVecPermute") {
added_node_name = LayoutOptimizerNode(strings::StrCat(
added_node_base_name, "-", kVecPermuteNCHWToNHWC));
- TF_RETURN_IF_ERROR(HasAttribute(*node_, "out_type"));
- DataType dtype = (IsSplit(*node_) || IsSplitV(*node_))
- ? DT_INT32
- : node_->attr().at("out_type").type();
AddNodeDataFormatOp(added_node_name, input, op, dtype, false);
} else {
return errors::InvalidArgument("Unsupported op type: ", op);
class DataLayoutOptimizer : GraphProcessor {
public:
explicit DataLayoutOptimizer(
+ const GraphProperties& graph_properties,
const VirtualPlacer& virtual_placer,
const LayoutOptimizer::TuningConfig& config,
const std::unordered_set<string>& nodes_to_preserve, GraphDef* graph,
NodeMap* node_map)
- : GraphProcessor(virtual_placer, nodes_to_preserve, graph, node_map),
+ : GraphProcessor(graph_properties, virtual_placer, nodes_to_preserve,
+ graph, node_map),
config_(config) {}
Status Optimize() {
ops_format_supported.end()) {
auto node = graph_->mutable_node(i);
bool is_in_frame = !frames[node].empty();
- OptimizeContext opt_cxt(graph_, node, node_map_, virtual_placer_,
- nodes_to_preserve_, is_in_frame);
+ OptimizeContext opt_cxt(graph_, node, node_map_, graph_properties_,
+ virtual_placer_, nodes_to_preserve_,
+ is_in_frame);
std::unique_ptr<NodeProcessor> node_processor;
if (IsAvgPoolGrad(*node)) {
node_processor.reset(new AvgPoolGradProcessor(opt_cxt));
ops_format_agnostic.end()) {
auto node = graph_->mutable_node(i);
bool is_in_frame = !frames[node].empty();
- OptimizeContext opt_cxt(graph_, node, node_map_, virtual_placer_,
- nodes_to_preserve_, is_in_frame);
+ OptimizeContext opt_cxt(graph_, node, node_map_, graph_properties_,
+ virtual_placer_, nodes_to_preserve_,
+ is_in_frame);
std::unique_ptr<NodeProcessor> node_processor;
if (IsAddN(*node)) {
node_processor.reset(new AddNProcessor(opt_cxt));
return status;
}
NodeMap node_map(output);
- DataLayoutOptimizer layout_optimizer(*virtual_placer_, config,
- nodes_to_preserve_, output, &node_map);
+ DataLayoutOptimizer layout_optimizer(graph_properties, *virtual_placer_,
+ config, nodes_to_preserve_, output,
+ &node_map);
status = layout_optimizer.Optimize();
return status;
}
self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+ def testCast(self):
+ if test.is_gpu_available(cuda_only=True):
+ random_seed.set_random_seed(0)
+ x = random_ops.truncated_normal([1, 784], seed=0)
+ conv = _two_layer_model(x)
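+ # Cast carries SrcT/DstT attrs rather than "T", so the layout optimizer
+ # must take its dtype from the inferred graph properties.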
+ cast = math_ops.cast(conv, dtype='bool')
+ output = array_ops.identity(cast)
+
+ with session.Session() as sess:
+ output_val_ref = sess.run(output)
+
+ with session.Session(config=_get_config()) as sess:
+ metadata = config_pb2.RunMetadata()
+ output_val = sess.run(output, run_metadata=metadata)
+
+ nodes = []
+ num_transposes = 0
+ for node in metadata.cost_graph.node:
+ if _is_transpose(node.name):
+ num_transposes += 1
+ nodes.append(node.name)
+
+ # Four transposes were initially added in the Expand phase of
+ # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
+ expected_num_transposes = 2
+ self.assertEqual(expected_num_transposes, num_transposes)
+ self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
+ self._assert_trans_nchw_to_nhwc('Cast-0-0', nodes)
+ self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+
def testReduceSumAlongHWC(self):
if test.is_gpu_available(cuda_only=True):
random_seed.set_random_seed(0)