std::vector<int> GetInputPos() const override { return {0, 1, 2}; }
};
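Previously Select was routed to TernaryOpProcessor, whose GetInputPos above unconditionally returns all three inputs for transposing. That is only correct when the condition (input 0) is 4-D; Select also accepts a scalar or a vector condition, so the change below adds a dedicated SelectProcessor that checks the condition's rank first.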
+class SelectProcessor : public AgnosticNodeProcessor {
+ public:
+  explicit SelectProcessor(const OptimizeContext& opt_cxt)
+      : AgnosticNodeProcessor(opt_cxt) {}
+
+ protected:
+  std::vector<int> GetInputPos() const override {
+    auto input0 = node_map_->GetNode(node_->input(0));
+    int input0_port;
+    ParseNodeName(node_->input(0), &input0_port);
+    // Input 0 (the condition) can be a scalar, a vector whose size matches
+    // the first dimension of inputs 1 and 2, or a tensor with the same shape
+    // as inputs 1 and 2. Only include it in the transpose set when it is 4-D.
+    if (IsPortDimsFour(*input0, input0_port)) {
+      return {0, 1, 2};
+    } else {
+      return {1, 2};
+    }
+  }
+};
+
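For reference, the condition shapes that the comment in SelectProcessor describes are the three forms the Select op accepts. Below is a minimal sketch of the three cases, using the same private gen_math_ops._select wrapper the tests in this change call; the constants are illustrative only, not part of the patch:

# Sketch: the three condition shapes Select accepts.
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import gen_math_ops

x = constant_op.constant([[1., 2.], [3., 4.]])
y = constant_op.constant([[5., 6.], [7., 8.]])

# Scalar condition: selects x or y as a whole.
s0 = gen_math_ops._select(constant_op.constant(True), x, y)
# Vector condition: size matches dim 0 of x/y; selects whole rows.
s1 = gen_math_ops._select(constant_op.constant([True, False]), x, y)
# Same-shape condition: elementwise selection.
s2 = gen_math_ops._select(
    constant_op.constant([[True, False], [False, True]]), x, y)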
class UnaryGradProcessor : public AgnosticNodeProcessor {
public:
explicit UnaryGradProcessor(const OptimizeContext& opt_cxt)
  std::unique_ptr<NodeProcessor> node_processor;
  if (IsAddN(*node)) {
    node_processor.reset(new AddNProcessor(opt_cxt));
-  } else if (IsBetainc(*node) || IsSelect(*node)) {
+  } else if (IsBetainc(*node)) {
    node_processor.reset(new TernaryOpProcessor(opt_cxt));
  } else if (IsBinaryOp(*node)) {
    node_processor.reset(new BinaryOpProcessor(opt_cxt));
    node_processor.reset(new ReduceProcessor(opt_cxt));
  } else if (IsReverseV2(*node)) {
    node_processor.reset(new ReverseProcessor(opt_cxt));
+  } else if (IsSelect(*node)) {
+    node_processor.reset(new SelectProcessor(opt_cxt));
  } else if (IsSlice(*node)) {
    node_processor.reset(new SliceProcessor(opt_cxt));
  } else if (IsStridedSlice(*node)) {
      self.assertIn('LayoutOptimizerDimMapNHWCToNCHW_ReverseV2_1', nodes)
      self.assertAllClose(output_val_ref, output_val, atol=1e-3)

-  def testTernaryOp(self):
+  def testSelectOp(self):
    if test.is_gpu_available(cuda_only=True):
      random_seed.set_random_seed(0)
      x = random_ops.truncated_normal([1, 784], seed=0)
      self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-Select-0-0', nodes)
      self.assertAllClose(output_val_ref, output_val, atol=1e-3)

+  def testSelectOpScalarCondition(self):
+    if test.is_gpu_available(cuda_only=True):
+      random_seed.set_random_seed(0)
+      x = random_ops.truncated_normal([1, 784], seed=0)
+      conv = _two_layer_model(x)
+      add = math_ops.add(conv, conv)
+      condition = constant_op.constant(True)
+      select = gen_math_ops._select(condition, conv, add)
+      output = array_ops.identity(select)
+
+      with session.Session() as sess:
+        output_val_ref = sess.run(output)
+
+      with session.Session(config=_get_config()) as sess:
+        metadata = config_pb2.RunMetadata()
+        output_val = sess.run(output, run_metadata=metadata)
+
+      nodes = []
+      num_transposes = 0
+      for node in metadata.cost_graph.node:
+        if node.name.startswith('LayoutOptimizerTranspose'):
+          num_transposes += 1
+        nodes.append(node.name)
+
+      # The scalar condition needs no transpose, so only two remain:
+      # NHWC->NCHW feeding Conv2D and NCHW->NHWC after Select.
+      expected_num_transposes = 2
+      self.assertEqual(expected_num_transposes, num_transposes)
+      self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
+      self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-Select-0-0', nodes)
+      self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+
  def testPadWithNonConstPaddings(self):
    if test.is_gpu_available(cuda_only=True):
      random_seed.set_random_seed(0)