super(ConvKFCBasicFB, self).__init__(layer_collection)
def instantiate_factors(self, grads_list, damping):
+ inputs, grads_list = self._process_data(grads_list)
+
# Infer number of locations upon which convolution is applied.
- self._num_locations = num_conv_locations(self._inputs[0].shape.as_list(),
+ self._num_locations = num_conv_locations(inputs[0].shape.as_list(),
self._strides)
- inputs, grads_list = self._process_data(grads_list)
-
self._input_factor = self._layer_collection.make_or_get_factor(
fisher_factors.ConvInputKroneckerFactor,
(inputs, self._filter_shape, self._padding, self._strides,
inputs, grads_list = self._process_data(grads_list)
# Infer number of locations upon which convolution is applied.
- self._num_locations = num_conv_locations(inputs.shape.as_list(),
+ self._num_locations = num_conv_locations(inputs[0].shape.as_list(),
self._strides)
self._input_factor = self._layer_collection.make_or_get_factor(