Fixed a bug in ConvKFCBasicMultiIndepFB introduced in the last CL
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 27 Mar 2018 23:43:48 +0000 (16:43 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 27 Mar 2018 23:48:25 +0000 (16:48 -0700)
PiperOrigin-RevId: 190695737

tensorflow/contrib/kfac/python/ops/fisher_blocks.py

index b04bf76..e0d9cb5 100644 (file)
@@ -861,12 +861,12 @@ class ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB):
     super(ConvKFCBasicFB, self).__init__(layer_collection)
 
   def instantiate_factors(self, grads_list, damping):
+    inputs, grads_list = self._process_data(grads_list)
+
     # Infer number of locations upon which convolution is applied.
-    self._num_locations = num_conv_locations(self._inputs[0].shape.as_list(),
+    self._num_locations = num_conv_locations(inputs[0].shape.as_list(),
                                              self._strides)
 
-    inputs, grads_list = self._process_data(grads_list)
-
     self._input_factor = self._layer_collection.make_or_get_factor(
         fisher_factors.ConvInputKroneckerFactor,
         (inputs, self._filter_shape, self._padding, self._strides,
@@ -1391,7 +1391,7 @@ class ConvKFCBasicMultiIndepFB(InputOutputMultiTowerMultiUse,
     inputs, grads_list = self._process_data(grads_list)
 
     # Infer number of locations upon which convolution is applied.
-    self._num_locations = num_conv_locations(inputs.shape.as_list(),
+    self._num_locations = num_conv_locations(inputs[0].shape.as_list(),
                                              self._strides)
 
     self._input_factor = self._layer_collection.make_or_get_factor(