From 71593602d95385fbd8c3dde361dab09d381b5ac6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 27 Mar 2018 16:43:48 -0700 Subject: [PATCH] Fixed a bug in ConvKFCBasicMultiIndepFB introduced in the last CL PiperOrigin-RevId: 190695737 --- tensorflow/contrib/kfac/python/ops/fisher_blocks.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py index b04bf76..e0d9cb5 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py @@ -861,12 +861,12 @@ class ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB): super(ConvKFCBasicFB, self).__init__(layer_collection) def instantiate_factors(self, grads_list, damping): + inputs, grads_list = self._process_data(grads_list) + # Infer number of locations upon which convolution is applied. - self._num_locations = num_conv_locations(self._inputs[0].shape.as_list(), + self._num_locations = num_conv_locations(inputs[0].shape.as_list(), self._strides) - inputs, grads_list = self._process_data(grads_list) - self._input_factor = self._layer_collection.make_or_get_factor( fisher_factors.ConvInputKroneckerFactor, (inputs, self._filter_shape, self._padding, self._strides, @@ -1391,7 +1391,7 @@ class ConvKFCBasicMultiIndepFB(InputOutputMultiTowerMultiUse, inputs, grads_list = self._process_data(grads_list) # Infer number of locations upon which convolution is applied. - self._num_locations = num_conv_locations(inputs.shape.as_list(), + self._num_locations = num_conv_locations(inputs[0].shape.as_list(), self._strides) self._input_factor = self._layer_collection.make_or_get_factor( -- 2.7.4