Fix kernel creation bug, due to constant folding always use CPU.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Sat, 28 Apr 2018 06:35:42 +0000 (23:35 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sat, 28 Apr 2018 06:38:42 +0000 (23:38 -0700)
PiperOrigin-RevId: 194636076

tensorflow/core/grappler/optimizers/layout_optimizer_test.cc

index fc87f69..dad49cd 100644 (file)
@@ -108,10 +108,8 @@ class LayoutOptimizerTest : public GrapplerTest {
 
     TensorShape filter_shape(
         {filter_size, filter_size, input_depth, filter_count});
-    Tensor filter_data(DT_FLOAT, filter_shape);
-    test::FillIota<float>(&filter_data, 1.0f);
     Output filter =
-        ops::Const(s->WithOpName("Filter"), Input::Initializer(filter_data));
+        ops::Variable(s->WithOpName("Filter"), filter_shape, DT_FLOAT);
 
     int output_height = input_height;
     int output_width = input_width;
@@ -143,6 +141,10 @@ class LayoutOptimizerTest : public GrapplerTest {
     return tensor;
   }
 
+  TensorShape GetAttrShape(const NodeDef& node) {
+    return TensorShape(node.attr().at({"shape"}).shape());
+  }
+
   Output SimpleFusedBatchNormGrad(tensorflow::Scope* s, bool is_training) {
     int batch_size = 16;
     int input_height = 8;
@@ -200,9 +202,12 @@ TEST_F(LayoutOptimizerTest, Conv2DBackpropInput) {
   test::ExpectTensorEqual<int>(input_sizes_expected, input_sizes);
 
   if (gpu_available_) {
+    TensorShape filter_shape = GetAttrShape(*node_map.GetNode("Filter"));
+    Tensor filter_data = GenerateRandomTensor<DT_FLOAT>(filter_shape);
     std::vector<string> fetch = {"Fetch"};
-    auto tensors_expected = EvaluateNodes(item.graph, fetch);
-    auto tensors = EvaluateNodes(output, fetch);
+    auto tensors_expected =
+        EvaluateNodes(item.graph, fetch, {{"Filter", filter_data}});
+    auto tensors = EvaluateNodes(output, fetch, {{"Filter", filter_data}});
     EXPECT_EQ(1, tensors_expected.size());
     EXPECT_EQ(1, tensors.size());
     test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);