From: A. Unique TensorFlower Date: Tue, 27 Mar 2018 23:43:48 +0000 (-0700) Subject: Fixed a bug in ConvKFCBasicMultiIndepFB introduced in the last CL X-Git-Tag: tflite-v0.1.7~67^2^2~84 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=71593602d95385fbd8c3dde361dab09d381b5ac6;p=platform%2Fupstream%2Ftensorflow.git Fixed a bug in ConvKFCBasicMultiIndepFB introduced in the last CL PiperOrigin-RevId: 190695737 --- diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py index b04bf76..e0d9cb5 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py @@ -861,12 +861,12 @@ class ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB): super(ConvKFCBasicFB, self).__init__(layer_collection) def instantiate_factors(self, grads_list, damping): + inputs, grads_list = self._process_data(grads_list) + # Infer number of locations upon which convolution is applied. - self._num_locations = num_conv_locations(self._inputs[0].shape.as_list(), + self._num_locations = num_conv_locations(inputs[0].shape.as_list(), self._strides) - inputs, grads_list = self._process_data(grads_list) - self._input_factor = self._layer_collection.make_or_get_factor( fisher_factors.ConvInputKroneckerFactor, (inputs, self._filter_shape, self._padding, self._strides, @@ -1391,7 +1391,7 @@ class ConvKFCBasicMultiIndepFB(InputOutputMultiTowerMultiUse, inputs, grads_list = self._process_data(grads_list) # Infer number of locations upon which convolution is applied. - self._num_locations = num_conv_locations(inputs.shape.as_list(), + self._num_locations = num_conv_locations(inputs[0].shape.as_list(), self._strides) self._input_factor = self._layer_collection.make_or_get_factor(