K-FAC: FisherBlocks for tf.nn.{depthwise_conv2d, separable_conv2d, convolution}.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 12 Mar 2018 21:23:49 +0000 (14:23 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 12 Mar 2018 21:30:43 +0000 (14:30 -0700)
PiperOrigin-RevId: 188778072

tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
tensorflow/contrib/kfac/python/kernel_tests/utils_test.py
tensorflow/contrib/kfac/python/ops/fisher_blocks.py
tensorflow/contrib/kfac/python/ops/fisher_factors.py
tensorflow/contrib/kfac/python/ops/layer_collection.py
tensorflow/contrib/kfac/python/ops/layer_collection_lib.py
tensorflow/contrib/kfac/python/ops/utils.py
tensorflow/contrib/kfac/python/ops/utils_lib.py

index c9c0f8e..b70c700 100644 (file)
@@ -764,6 +764,54 @@ class ConvDiagonalFBTest(test.TestCase):
     return multiply_result, multiply_inverse_result
 
 
+class DepthwiseConvKFCBasicFBTest(test.TestCase):
+
+  def testInstantiateFactors(self):
+    with ops.Graph().as_default():
+      random_seed.set_random_seed(200)
+      params = random_ops.random_normal((3, 3, 8, 2))
+      inputs = random_ops.random_normal((32, 5, 5, 8))
+      outputs = random_ops.random_normal((32, 5, 5, 16))
+      layer_collection = lc.LayerCollection()
+      block = fb.DepthwiseConvKFCBasicFB(
+          layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME')
+      block.register_additional_minibatch(inputs, outputs)
+      grads = outputs**2
+      block.instantiate_factors(([grads],), 0.5)
+
+  def testMultiplyInverse(self):
+    with ops.Graph().as_default(), self.test_session() as sess:
+      random_seed.set_random_seed(200)
+      params = random_ops.random_normal((3, 3, 8, 2))
+      inputs = random_ops.random_normal((32, 5, 5, 8))
+      outputs = random_ops.random_normal((32, 5, 5, 16))
+      layer_collection = lc.LayerCollection()
+      block = fb.DepthwiseConvKFCBasicFB(
+          layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME')
+      block.register_additional_minibatch(inputs, outputs)
+      grads = outputs**2
+      block.instantiate_factors(([grads],), 0.5)
+      block._input_factor.instantiate_cov_variables()
+      block._output_factor.instantiate_cov_variables()
+      block.register_inverse()
+      block._input_factor.instantiate_inv_variables()
+      block._output_factor.instantiate_inv_variables()
+
+      # Ensure inverse update op doesn't crash.
+      sess.run(tf_variables.global_variables_initializer())
+      sess.run([
+          factor.make_inverse_update_ops()
+          for factor in layer_collection.get_factors()
+      ])
+
+      # Ensure inverse-vector multiply doesn't crash.
+      output = block.multiply_inverse(params)
+      sess.run(output)
+
+      # Ensure same shape.
+      self.assertAllEqual(output.shape, params.shape)
+
+
 class ConvKFCBasicFBTest(test.TestCase):
 
   def _testConvKFCBasicFBInitParams(self, params):
@@ -775,16 +823,17 @@ class ConvKFCBasicFBTest(test.TestCase):
         params = array_ops.constant(params)
       inputs = random_ops.random_normal((2, 2, 2))
       outputs = random_ops.random_normal((2, 2, 2))
-      block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, [1, 1, 1], 'SAME')
+      block = fb.ConvKFCBasicFB(
+          lc.LayerCollection(), params=params, padding='SAME')
       block.register_additional_minibatch(inputs, outputs)
 
       self.assertAllEqual([outputs], block.tensors_to_compute_grads())
 
   def testConvKFCBasicFBInitParamsParamsTuple(self):
-    self._testConvKFCBasicFBInitParams([np.array([1., 2.]), np.array(3.)])
+    self._testConvKFCBasicFBInitParams([np.ones([1, 2, 2]), np.ones([2])])
 
   def testConvKFCBasicFBInitParamsParamsSingle(self):
-    self._testConvKFCBasicFBInitParams([np.array([1., 2.])])
+    self._testConvKFCBasicFBInitParams([np.ones([1, 2, 2])])
 
   def testMultiplyInverseTuple(self):
     with ops.Graph().as_default(), self.test_session() as sess:
@@ -792,8 +841,8 @@ class ConvKFCBasicFBTest(test.TestCase):
       params = random_ops.random_normal((2, 2, 2, 2))
       inputs = random_ops.random_normal((2, 2, 2, 2))
       outputs = random_ops.random_normal((2, 2, 2, 2))
-      block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1),
-                                'SAME')
+      block = fb.ConvKFCBasicFB(
+          lc.LayerCollection(), params=params, padding='SAME')
       block.register_additional_minibatch(inputs, outputs)
       grads = outputs**2
       block.instantiate_factors(((grads,),), 0.5)
@@ -823,8 +872,8 @@ class ConvKFCBasicFBTest(test.TestCase):
       params = random_ops.random_normal((2, 2, 2, 2))
       inputs = random_ops.random_normal((2, 2, 2, 2))
       outputs = random_ops.random_normal((2, 2, 2, 2))
-      block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1),
-                                'SAME')
+      block = fb.ConvKFCBasicFB(
+          lc.LayerCollection(), params=params, padding='SAME')
       block.register_additional_minibatch(inputs, outputs)
       self.assertFalse(block._has_bias)
       grads = outputs**2
@@ -851,8 +900,8 @@ class ConvKFCBasicFBTest(test.TestCase):
       params = [random_ops.random_normal((2, 2, 2, 2))]
       inputs = random_ops.random_normal((2, 2, 2, 2))
       outputs = random_ops.random_normal((2, 2, 2, 2))
-      block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1),
-                                'SAME')
+      block = fb.ConvKFCBasicFB(
+          lc.LayerCollection(), params=params, padding='SAME')
       block.register_additional_minibatch(inputs, outputs)
       self.assertTrue(block._has_bias)
       grads = outputs**2
@@ -879,8 +928,8 @@ class ConvKFCBasicFBTest(test.TestCase):
       params = array_ops.zeros((2, 2, 2, 2))
       inputs = array_ops.zeros((2, 2, 2, 2))
       outputs = array_ops.zeros((2, 2, 2, 2))
-      block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1),
-                                'SAME')
+      block = fb.ConvKFCBasicFB(
+          lc.LayerCollection(), params=params, padding='SAME')
       block.register_additional_minibatch(inputs, outputs)
       grads = outputs**2
       damping = 0.  # This test is only valid without damping.
index beb427b..16f02f1 100644 (file)
@@ -23,12 +23,14 @@ import numpy.random as npr
 
 from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
 from tensorflow.contrib.kfac.python.ops import fisher_factors as ff
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops as tf_ops
 from tensorflow.python.framework import random_seed
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import variables as tf_variables
 from tensorflow.python.platform import test
 
@@ -447,6 +449,117 @@ class EmbeddingInputKroneckerFactorTest(test.TestCase):
         self.assertAllClose(np.array([1., 1., 0., 0., 1.]) / 3., new_cov)
 
 
+class ConvDiagonalFactorTest(test.TestCase):
+
+  def setUp(self):
+    self.batch_size = 10
+    self.height = self.width = 32
+    self.in_channels = 3
+    self.out_channels = 1
+    self.kernel_height = self.kernel_width = 3
+    self.strides = [1, 2, 2, 1]
+    self.data_format = 'NHWC'
+    self.padding = 'SAME'
+    self.kernel_shape = [
+        self.kernel_height, self.kernel_width, self.in_channels,
+        self.out_channels
+    ]
+
+  def testInit(self):
+    with tf_ops.Graph().as_default():
+      inputs = random_ops.random_uniform(
+          [self.batch_size, self.height, self.width, self.in_channels])
+      outputs_grads = [
+          random_ops.random_uniform([
+              self.batch_size, self.height // self.strides[1],
+              self.width // self.strides[2], self.out_channels
+          ]) for _ in range(3)
+      ]
+
+      factor = ff.ConvDiagonalFactor(
+          inputs,
+          outputs_grads,
+          self.kernel_shape,
+          self.strides,
+          self.padding,
+          data_format=self.data_format)
+      factor.instantiate_cov_variables()
+
+      # Ensure covariance matrix's shape makes sense.
+      self.assertEqual([
+          self.kernel_height * self.kernel_width * self.in_channels,
+          self.out_channels
+      ],
+                       factor.get_cov_var().shape.as_list())
+
+  def testMakeCovarianceUpdateOp(self):
+    with tf_ops.Graph().as_default():
+      # Construct all arguments such that convolution kernel is applied in
+      # exactly one spatial location.
+      inputs = np.random.randn(
+          1,  # batch_size
+          self.kernel_height,
+          self.kernel_width,
+          self.in_channels)  # in_channels
+      outputs_grad = np.random.randn(
+          1,  # batch_size
+          1,  # output_height
+          1,  # output_width
+          self.out_channels)
+
+      factor = ff.ConvDiagonalFactor(
+          constant_op.constant(inputs), [constant_op.constant(outputs_grad)],
+          self.kernel_shape,
+          strides=[1, 1, 1, 1],
+          padding='VALID')
+      factor.instantiate_cov_variables()
+
+      # Completely forget initial value on first update.
+      cov_update_op = factor.make_covariance_update_op(0.0)
+
+      # Ensure new covariance value is same as outer-product of inputs/outputs
+      # vectorized, squared.
+      with self.test_session() as sess:
+        sess.run(tf_variables.global_variables_initializer())
+        cov = sess.run(cov_update_op)
+        expected_cov = np.outer(inputs.flatten(), outputs_grad.flatten())**2
+        self.assertAllClose(expected_cov, cov)
+
+  def testHasBias(self):
+    with tf_ops.Graph().as_default():
+      inputs = random_ops.random_uniform(
+          [self.batch_size, self.height, self.width, self.in_channels])
+      outputs_grads = [
+          random_ops.random_uniform([
+              self.batch_size, self.height // self.strides[1],
+              self.width // self.strides[2], self.out_channels
+          ]) for _ in range(3)
+      ]
+
+      factor = ff.ConvDiagonalFactor(
+          inputs,
+          outputs_grads,
+          self.kernel_shape,
+          self.strides,
+          self.padding,
+          data_format=self.data_format,
+          has_bias=True)
+      factor.instantiate_cov_variables()
+
+      # Ensure shape accounts for bias.
+      self.assertEqual([
+          self.kernel_height * self.kernel_width * self.in_channels + 1,
+          self.out_channels
+      ],
+                       factor.get_cov_var().shape.as_list())
+
+      # Ensure update op doesn't crash.
+      cov_update_op = factor.make_covariance_update_op(0.0)
+      with self.test_session() as sess:
+        sess.run(tf_variables.global_variables_initializer())
+        sess.run(cov_update_op)
+
+
 class FullyConnectedKroneckerFactorTest(test.TestCase):
 
   def _testFullyConnectedKroneckerFactorInit(self,
@@ -493,24 +606,152 @@ class FullyConnectedKroneckerFactorTest(test.TestCase):
       self.assertAllClose([[3, 3.5], [3.5, 5.5]], new_cov)
 
 
-class ConvInputKroneckerFactorTest(test.TestCase):
+class ConvFactorTestCase(test.TestCase):
+
+  def assertMatrixRank(self, rank, matrix, atol=1e-5):
+    assert rank <= matrix.shape[0], 'Rank cannot be larger than matrix size.'
+    eigvals = np.linalg.eigvals(matrix)
+    nnz_eigvals = np.sum(eigvals > atol)
+    self.assertEqual(
+        rank,
+        nnz_eigvals,
+        msg=('Found %d of %d expected non-zero eigenvalues: %s.' %
+             (nnz_eigvals, rank, eigvals)))
+
+
+class ConvInputKroneckerFactorTest(ConvFactorTestCase):
+
+  def test3DConvolution(self):
+    with tf_ops.Graph().as_default():
+      batch_size = 1
+      width = 3
+      in_channels = 3**3
+      out_channels = 4
+
+      factor = ff.ConvInputKroneckerFactor(
+          inputs=random_ops.random_uniform(
+              (batch_size, width, width, width, in_channels), seed=0),
+          filter_shape=(width, width, width, in_channels, out_channels),
+          padding='SAME',
+          strides=(2, 2, 2),
+          extract_patches_fn='extract_convolution_patches',
+          has_bias=False)
+      factor.instantiate_cov_variables()
+
+      # Ensure shape of covariance matches input size of filter.
+      input_size = in_channels * (width**3)
+      self.assertEqual([input_size, input_size],
+                       factor.get_cov_var().shape.as_list())
+
+      # Ensure cov_update_op doesn't crash.
+      with self.test_session() as sess:
+        sess.run(tf_variables.global_variables_initializer())
+        sess.run(factor.make_covariance_update_op(0.0))
+        cov = sess.run(factor.get_cov_var())
+
+      # Cov should be rank-8, as the filter will be applied at each corner of
+      # the 4-D cube.
+      self.assertMatrixRank(8, cov)
+
+  def testPointwiseConv2d(self):
+    with tf_ops.Graph().as_default():
+      batch_size = 1
+      width = 3
+      in_channels = 3**2
+      out_channels = 4
+
+      factor = ff.ConvInputKroneckerFactor(
+          inputs=random_ops.random_uniform(
+              (batch_size, width, width, in_channels), seed=0),
+          filter_shape=(1, 1, in_channels, out_channels),
+          padding='SAME',
+          strides=(1, 1, 1, 1),
+          extract_patches_fn='extract_pointwise_conv2d_patches',
+          has_bias=False)
+      factor.instantiate_cov_variables()
+
+      # Ensure shape of covariance matches input size of filter.
+      self.assertEqual([in_channels, in_channels],
+                       factor.get_cov_var().shape.as_list())
+
+      # Ensure cov_update_op doesn't crash.
+      with self.test_session() as sess:
+        sess.run(tf_variables.global_variables_initializer())
+        sess.run(factor.make_covariance_update_op(0.0))
+        cov = sess.run(factor.get_cov_var())
+
+      # Cov should be rank-9, as the filter will be applied at each location.
+      self.assertMatrixRank(9, cov)
+
+  def testStrides(self):
+    with tf_ops.Graph().as_default():
+      batch_size = 1
+      width = 3
+      in_channels = 3**2
+      out_channels = 4
+
+      factor = ff.ConvInputKroneckerFactor(
+          inputs=random_ops.random_uniform(
+              (batch_size, width, width, in_channels), seed=0),
+          filter_shape=(1, 1, in_channels, out_channels),
+          padding='SAME',
+          strides=(1, 2, 1, 1),
+          extract_patches_fn='extract_image_patches',
+          has_bias=False)
+      factor.instantiate_cov_variables()
+
+      with self.test_session() as sess:
+        sess.run(tf_variables.global_variables_initializer())
+        sess.run(factor.make_covariance_update_op(0.0))
+        cov = sess.run(factor.get_cov_var())
+
+      # Cov should be the sum of 3 * 2 = 6 outer products.
+      self.assertMatrixRank(6, cov)
+
+  def testDilationRate(self):
+    with tf_ops.Graph().as_default():
+      batch_size = 1
+      width = 3
+      in_channels = 2
+      out_channels = 4
+
+      factor = ff.ConvInputKroneckerFactor(
+          inputs=random_ops.random_uniform(
+              (batch_size, width, width, in_channels), seed=0),
+          filter_shape=(3, 3, in_channels, out_channels),
+          padding='SAME',
+          extract_patches_fn='extract_image_patches',
+          strides=(1, 1, 1, 1),
+          dilation_rate=(1, width, width, 1),
+          has_bias=False)
+      factor.instantiate_cov_variables()
+
+      with self.test_session() as sess:
+        sess.run(tf_variables.global_variables_initializer())
+        sess.run(factor.make_covariance_update_op(0.0))
+        cov = sess.run(factor.get_cov_var())
+
+      # Cov should be rank = in_channels, as only the center of the filter
+      # receives non-zero input for each input channel.
+      self.assertMatrixRank(in_channels, cov)
 
   def testConvInputKroneckerFactorInitNoBias(self):
     with tf_ops.Graph().as_default():
-      random_seed.set_random_seed(200)
-      tensor = array_ops.ones((2, 3), name='a/b/c')
+      tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c')
       factor = ff.ConvInputKroneckerFactor(
-          tensor, (1, 2, 3, 4), 3, 2, has_bias=False)
+          inputs=tensor,
+          filter_shape=(1, 2, 3, 4),
+          padding='SAME',
+          has_bias=False)
       factor.instantiate_cov_variables()
       self.assertEqual([1 * 2 * 3, 1 * 2 * 3],
                        factor.get_cov().get_shape().as_list())
 
   def testConvInputKroneckerFactorInit(self):
     with tf_ops.Graph().as_default():
-      random_seed.set_random_seed(200)
-      tensor = array_ops.ones((2, 3), name='a/b/c')
+      tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c')
       factor = ff.ConvInputKroneckerFactor(
-          tensor, (1, 2, 3, 4), 3, 2, has_bias=True)
+          tensor, filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True)
       factor.instantiate_cov_variables()
       self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1],
                        factor.get_cov().get_shape().as_list())
@@ -518,10 +759,9 @@ class ConvInputKroneckerFactorTest(test.TestCase):
   def testConvInputKroneckerFactorInitFloat64(self):
     with tf_ops.Graph().as_default():
       dtype = dtypes.float64_ref
-      random_seed.set_random_seed(200)
-      tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
+      tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c', dtype=dtypes.float64)
       factor = ff.ConvInputKroneckerFactor(
-          tensor, (1, 2, 3, 4), 3, 2, has_bias=True)
+          tensor, filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True)
       factor.instantiate_cov_variables()
       cov = factor.get_cov()
       self.assertEqual(cov.dtype, dtype)
@@ -530,33 +770,60 @@ class ConvInputKroneckerFactorTest(test.TestCase):
 
   def testMakeCovarianceUpdateOpWithBias(self):
     with tf_ops.Graph().as_default(), self.test_session() as sess:
-      random_seed.set_random_seed(200)
+      input_shape = (2, 1, 1, 1)
       tensor = array_ops.constant(
-          np.arange(1., 17.).reshape(2, 2, 2, 2), dtype=dtypes.float32)
+          np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype(
+              np.float32))
       factor = ff.ConvInputKroneckerFactor(
-          tensor, (1, 2, 1, 1), [1, 1, 1, 1], 'SAME', has_bias=True)
+          tensor, filter_shape=(1, 1, 1, 1), padding='SAME', has_bias=True)
       factor.instantiate_cov_variables()
 
       sess.run(tf_variables.global_variables_initializer())
-      new_cov = sess.run(factor.make_covariance_update_op(.5))
-      self.assertAllClose([[34.375, 37, 3.125], [37, 41, 3.5], [3.125, 3.5, 1]],
-                          new_cov)
+      new_cov = sess.run(factor.make_covariance_update_op(0.))
+      self.assertAllClose(
+          [
+              [(1. + 4.) / 2., (1. + 2.) / 2.],  #
+              [(1. + 2.) / 2., (1. + 1.) / 2.]
+          ],  #
+          new_cov)
 
   def testMakeCovarianceUpdateOpNoBias(self):
     with tf_ops.Graph().as_default(), self.test_session() as sess:
-      random_seed.set_random_seed(200)
+      input_shape = (2, 1, 1, 1)
       tensor = array_ops.constant(
-          np.arange(1., 17.).reshape(2, 2, 2, 2), dtype=dtypes.float32)
-      factor = ff.ConvInputKroneckerFactor(tensor, (1, 2, 1, 1),
-                                           [1, 1, 1, 1], 'SAME')
+          np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype(
+              np.float32))
+      factor = ff.ConvInputKroneckerFactor(
+          tensor, filter_shape=(1, 1, 1, 1), padding='SAME')
       factor.instantiate_cov_variables()
 
       sess.run(tf_variables.global_variables_initializer())
-      new_cov = sess.run(factor.make_covariance_update_op(.5))
-      self.assertAllClose([[34.375, 37], [37, 41]], new_cov)
+      new_cov = sess.run(factor.make_covariance_update_op(0.))
+      self.assertAllClose([[(1. + 4.) / 2.]], new_cov)
 
 
-class ConvOutputKroneckerFactorTest(test.TestCase):
+class ConvOutputKroneckerFactorTest(ConvFactorTestCase):
+
+  def test3DConvolution(self):
+    with tf_ops.Graph().as_default():
+      batch_size = 1
+      width = 3
+      out_channels = width**3
+
+      factor = ff.ConvOutputKroneckerFactor(outputs_grads=[
+          random_ops.random_uniform(
+              (batch_size, width, width, width, out_channels), seed=0)
+      ])
+      factor.instantiate_cov_variables()
+
+      with self.test_session() as sess:
+        sess.run(tf_variables.global_variables_initializer())
+        sess.run(factor.make_covariance_update_op(0.0))
+        cov = sess.run(factor.get_cov())
+
+      # Cov should be rank 3^3, as each spatial position donates a rank-1
+      # update.
+      self.assertMatrixRank(width**3, cov)
 
   def testConvOutputKroneckerFactorInit(self):
     with tf_ops.Graph().as_default():
@@ -577,13 +844,6 @@ class ConvOutputKroneckerFactorTest(test.TestCase):
       self.assertEqual(cov.dtype, dtype)
       self.assertEqual([5, 5], cov.get_shape().as_list())
 
-  def testConvOutputKroneckerFactorInitNotEnoughDims(self):
-    with tf_ops.Graph().as_default():
-      random_seed.set_random_seed(200)
-      tensor = array_ops.ones((2, 3), name='a/b/c')
-      with self.assertRaises(IndexError):
-        ff.ConvOutputKroneckerFactor((tensor,))
-
   def testMakeCovarianceUpdateOp(self):
     with tf_ops.Graph().as_default(), self.test_session() as sess:
       random_seed.set_random_seed(200)
index 889f336..bae6bd7 100644 (file)
@@ -104,14 +104,31 @@ class LayerCollectionTest(test.TestCase):
           array_ops.constant(3),
           approx=layer_collection.APPROX_DIAGONAL_NAME)
       lc.register_conv2d(
-          array_ops.constant(4), [1, 1, 1, 1], 'SAME',
-          array_ops.ones((1, 1, 1, 1)), array_ops.constant(3))
+          params=array_ops.ones((2, 3, 4, 5)),
+          strides=[1, 1, 1, 1],
+          padding='SAME',
+          inputs=array_ops.ones((1, 2, 3, 4)),
+          outputs=array_ops.ones((1, 1, 1, 5)))
       lc.register_conv2d(
-          array_ops.constant(4), [1, 1, 1, 1],
-          'SAME',
-          array_ops.ones((1, 1, 1, 1)),
-          array_ops.constant(3),
+          params=array_ops.ones((2, 3, 4, 5)),
+          strides=[1, 1, 1, 1],
+          padding='SAME',
+          inputs=array_ops.ones((1, 2, 3, 4)),
+          outputs=array_ops.ones((1, 1, 1, 5)),
           approx=layer_collection.APPROX_DIAGONAL_NAME)
+      lc.register_separable_conv2d(
+          depthwise_params=array_ops.ones((3, 3, 1, 2)),
+          pointwise_params=array_ops.ones((1, 1, 2, 4)),
+          inputs=array_ops.ones((32, 5, 5, 1)),
+          depthwise_outputs=array_ops.ones((32, 5, 5, 2)),
+          pointwise_outputs=array_ops.ones((32, 5, 5, 4)),
+          strides=[1, 1, 1, 1],
+          padding='SAME')
+      lc.register_convolution(
+          params=array_ops.ones((3, 3, 1, 8)),
+          inputs=array_ops.ones((32, 5, 5, 1)),
+          outputs=array_ops.ones((32, 5, 5, 8)),
+          padding='SAME')
       lc.register_generic(
           array_ops.constant(5), 16, approx=layer_collection.APPROX_FULL_NAME)
       lc.register_generic(
@@ -119,7 +136,7 @@ class LayerCollectionTest(test.TestCase):
           16,
           approx=layer_collection.APPROX_DIAGONAL_NAME)
 
-      self.assertEqual(6, len(lc.get_blocks()))
+      self.assertEqual(9, len(lc.get_blocks()))
 
   def testRegisterBlocksMultipleRegistrations(self):
     with ops.Graph().as_default():
@@ -535,6 +552,32 @@ class LayerCollectionTest(test.TestCase):
     self.assertIsInstance(lc.fisher_blocks[b_0], fisher_blocks.FullFB)
     self.assertIsInstance(lc.fisher_blocks[b_1], fisher_blocks.NaiveDiagonalFB)
 
+  def testDefaultLayerCollection(self):
+    with ops.Graph().as_default():
+      # Can't get default if there isn't one set.
+      with self.assertRaises(ValueError):
+        layer_collection.get_default_layer_collection()
+
+      # Can't set default twice.
+      lc = layer_collection.LayerCollection()
+      layer_collection.set_default_layer_collection(lc)
+      with self.assertRaises(ValueError):
+        layer_collection.set_default_layer_collection(lc)
+
+      # Same as one set.
+      self.assertTrue(lc is layer_collection.get_default_layer_collection())
+
+      # Can set to None.
+      layer_collection.set_default_layer_collection(None)
+      with self.assertRaises(ValueError):
+        layer_collection.get_default_layer_collection()
+
+      # as_default() is the same as setting/clearing.
+      with lc.as_default():
+        self.assertTrue(lc is layer_collection.get_default_layer_collection())
+      with self.assertRaises(ValueError):
+        layer_collection.get_default_layer_collection()
+
 
 if __name__ == '__main__':
   test.main()
index 97a97ad..2cee012 100644 (file)
@@ -29,6 +29,8 @@ from tensorflow.python.framework import random_seed
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import linalg_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
@@ -325,6 +327,84 @@ class UtilsTest(test.TestCase):
           ],
           values)
 
+  def testExtractConvolutionPatches(self):
+    with ops.Graph().as_default(), self.test_session() as sess:
+      batch_size = 10
+      image_spatial_shape = [9, 10, 11]
+      in_channels = out_channels = 32
+      kernel_spatial_shape = [5, 3, 3]
+      spatial_strides = [1, 2, 1]
+      spatial_dilation = [1, 1, 1]
+      padding = 'SAME'
+
+      images = random_ops.random_uniform(
+          [batch_size] + image_spatial_shape + [in_channels], seed=0)
+      kernel_shape = kernel_spatial_shape + [in_channels, out_channels]
+      kernel = random_ops.random_uniform(kernel_shape, seed=1)
+
+      # Ensure shape matches expectation.
+      patches = utils.extract_convolution_patches(
+          images,
+          kernel_shape,
+          padding,
+          strides=spatial_strides,
+          dilation_rate=spatial_dilation)
+      result_spatial_shape = (
+          patches.shape.as_list()[1:1 + len(image_spatial_shape)])
+      self.assertEqual(patches.shape.as_list(),
+                       [batch_size] + result_spatial_shape +
+                       kernel_spatial_shape + [in_channels])
+
+      # Ensure extract...patches() + matmul() and convolution() implementation
+      # give the same answer.
+      outputs = nn_ops.convolution(
+          images,
+          kernel,
+          padding,
+          strides=spatial_strides,
+          dilation_rate=spatial_dilation)
+
+      patches_flat = array_ops.reshape(
+          patches, [-1, np.prod(kernel_spatial_shape) * in_channels])
+      kernel_flat = array_ops.reshape(kernel, [-1, out_channels])
+      outputs_flat = math_ops.matmul(patches_flat, kernel_flat)
+
+      outputs_, outputs_flat_ = sess.run([outputs, outputs_flat])
+      self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten())
+
+  def testExtractPointwiseConv2dPatches(self):
+    with ops.Graph().as_default(), self.test_session() as sess:
+      batch_size = 10
+      image_height = image_width = 8
+      in_channels = out_channels = 3
+      kernel_height = kernel_width = 1
+      strides = [1, 1, 1, 1]
+      padding = 'VALID'
+
+      images = random_ops.random_uniform(
+          [batch_size, image_height, image_width, in_channels], seed=0)
+      kernel_shape = [kernel_height, kernel_width, in_channels, out_channels]
+      kernel = random_ops.random_uniform(kernel_shape, seed=1)
+
+      # Ensure shape matches expectation.
+      patches = utils.extract_pointwise_conv2d_patches(images, kernel_shape)
+      self.assertEqual(patches.shape.as_list(), [
+          batch_size, image_height, image_width, kernel_height, kernel_width,
+          in_channels
+      ])
+
+      # Ensure extract...patches() + matmul() and conv2d() implementation
+      # give the same answer.
+      outputs = nn_ops.conv2d(images, kernel, strides, padding)
+
+      patches_flat = array_ops.reshape(
+          patches, [-1, kernel_height * kernel_width * in_channels])
+      kernel_flat = array_ops.reshape(kernel, [-1, out_channels])
+      outputs_flat = math_ops.matmul(patches_flat, kernel_flat)
+
+      outputs_, outputs_flat_ = sess.run([outputs, outputs_flat])
+      self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten())
+
 
 if __name__ == '__main__':
   test.main()
index 521a988..31f4689 100644 (file)
@@ -40,10 +40,12 @@ from __future__ import print_function
 import abc
 import enum  # pylint: disable=g-bad-import-order
 
+import numpy as np
 import six
 
 from tensorflow.contrib.kfac.python.ops import fisher_factors
 from tensorflow.contrib.kfac.python.ops import utils
+from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 
@@ -517,7 +519,7 @@ class FullyConnectedDiagonalFB(InputOutputMultiMinibatch, FisherBlock):
 
 
 class ConvDiagonalFB(InputOutputMultiMinibatch, FisherBlock):
-  """FisherBlock for convolutional layers using a diagonal approx.
+  """FisherBlock for 2-D convolutional layers using a diagonal approx.
 
   Estimates the Fisher Information matrix's diagonal entries for a convolutional
   layer. Unlike NaiveDiagonalFB this uses the low-variance "sum of squares"
@@ -541,7 +543,13 @@ class ConvDiagonalFB(InputOutputMultiMinibatch, FisherBlock):
   to the layer's parameters 'w'.
   """
 
-  def __init__(self, layer_collection, params, strides, padding):
+  def __init__(self,
+               layer_collection,
+               params,
+               strides,
+               padding,
+               data_format=None,
+               dilations=None):
     """Creates a ConvDiagonalFB block.
 
     Args:
@@ -553,29 +561,53 @@ class ConvDiagonalFB(InputOutputMultiMinibatch, FisherBlock):
         containing the previous and a Tensor of shape [out_channels].
       strides: The stride size in this layer (1-D Tensor of length 4).
       padding: The padding in this layer (e.g. "SAME").
+      data_format: str or None. Format of input data.
+      dilations: List of 4 ints or None. Rate for dilation along all dimensions.
+
+    Raises:
+      ValueError: if strides is not length-4.
+      ValueError: if dilations is not length-4.
+      ValueError: if channel is not last dimension.
     """
-    self._strides = tuple(strides) if isinstance(strides, list) else strides
+    if len(strides) != 4:
+      raise ValueError("strides must contain 4 numbers.")
+
+    if dilations is None:
+      dilations = [1, 1, 1, 1]
+
+    if len(dilations) != 4:
+      raise ValueError("dilations must contain 4 numbers.")
+
+    if not utils.is_data_format_channel_last(data_format):
+      raise ValueError("data_format must be channels-last.")
+
+    self._strides = maybe_tuple(strides)
     self._padding = padding
+    self._data_format = data_format
+    self._dilations = maybe_tuple(dilations)
     self._has_bias = isinstance(params, (tuple, list))
 
     fltr = params[0] if self._has_bias else params
     self._filter_shape = tuple(fltr.shape.as_list())
 
+    if len(self._filter_shape) != 4:
+      raise ValueError(
+          "Convolution filter must be of shape"
+          " [filter_height, filter_width, in_channels, out_channels].")
+
     super(ConvDiagonalFB, self).__init__(layer_collection)
 
   def instantiate_factors(self, grads_list, damping):
-    # Infer number of locations upon which convolution is applied.
-    inputs_shape = tuple(self._inputs[0].shape.as_list())
-    self._num_locations = (
-        inputs_shape[1] * inputs_shape[2] //
-        (self._strides[1] * self._strides[2]))
-
     inputs, grads_list = self._package_minibatches(grads_list)
 
+    # Infer number of locations upon which convolution is applied.
+    self._num_locations = num_conv_locations(inputs.shape.as_list(),
+                                             self._strides)
+
     self._factor = self._layer_collection.make_or_get_factor(
         fisher_factors.ConvDiagonalFactor,
-        (inputs, grads_list, self._filter_shape, self._strides,
-         self._padding, self._has_bias))
+        (inputs, grads_list, self._filter_shape, self._strides, self._padding,
+         self._data_format, self._dilations, self._has_bias))
 
     def damping_func():
       return self._num_locations * normalize_damping(damping,
@@ -658,8 +690,8 @@ class KroneckerProductFB(FisherBlock):
     reshaped_out = self._input_factor.left_multiply_matpower(
         reshaped_out, exp, self._input_damping_func)
     if self._renorm_coeff != 1.0:
-      reshaped_out *= math_ops.cast(
-          self._renorm_coeff**exp, dtype=reshaped_out.dtype)
+      renorm_coeff = math_ops.cast(self._renorm_coeff, dtype=reshaped_out.dtype)
+      reshaped_out *= math_ops.cast(renorm_coeff**exp, dtype=reshaped_out.dtype)
     return utils.mat2d_to_layer_params(vector, reshaped_out)
 
   def full_fisher_block(self):
@@ -761,7 +793,7 @@ class FullyConnectedKFACBasicFB(InputOutputMultiMinibatch, KroneckerProductFB):
 
 
 class ConvKFCBasicFB(InputOutputMultiMinibatch, KroneckerProductFB):
-  """FisherBlock for 2D convolutional layers using the basic KFC approx.
+  """FisherBlock for convolutional layers using the basic KFC approx.
 
   Estimates the Fisher Information matrix's blog for a convolutional
   layer.
@@ -784,21 +816,40 @@ class ConvKFCBasicFB(InputOutputMultiMinibatch, KroneckerProductFB):
   See equation 23 in https://arxiv.org/abs/1602.01407 for details.
   """
 
-  def __init__(self, layer_collection, params, strides, padding):
+  def __init__(self,
+               layer_collection,
+               params,
+               padding,
+               strides=None,
+               dilation_rate=None,
+               data_format=None,
+               extract_patches_fn=None):
     """Creates a ConvKFCBasicFB block.
 
     Args:
       layer_collection: The collection of all layers in the K-FAC approximate
           Fisher information matrix to which this FisherBlock belongs.
       params: The parameters (Tensor or tuple of Tensors) of this layer. If
-        kernel alone, a Tensor of shape [kernel_height, kernel_width,
+        kernel alone, a Tensor of shape [..spatial_filter_shape..,
         in_channels, out_channels]. If kernel and bias, a tuple of 2 elements
         containing the previous and a Tensor of shape [out_channels].
-      strides: The stride size in this layer (1-D Tensor of length 4).
-      padding: The padding in this layer (1-D of Tensor length 4).
+      padding: str. Padding method.
+      strides: List of ints or None. Contains [..spatial_filter_strides..] if
+        'extract_patches_fn' is compatible with tf.nn.convolution(), else
+        [1, ..spatial_filter_strides, 1].
+      dilation_rate: List of ints or None. Rate for dilation along each spatial
+        dimension if 'extract_patches_fn' is compatible with
+        tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1].
+      data_format: str or None. Format of input data.
+      extract_patches_fn: str or None. Name of function that extracts image
+        patches. One of "extract_convolution_patches", "extract_image_patches",
+        "extract_pointwise_conv2d_patches".
     """
-    self._strides = tuple(strides) if isinstance(strides, list) else strides
     self._padding = padding
+    self._strides = maybe_tuple(strides)
+    self._dilation_rate = maybe_tuple(dilation_rate)
+    self._data_format = data_format
+    self._extract_patches_fn = extract_patches_fn
     self._has_bias = isinstance(params, (tuple, list))
 
     fltr = params[0] if self._has_bias else params
@@ -807,15 +858,16 @@ class ConvKFCBasicFB(InputOutputMultiMinibatch, KroneckerProductFB):
     super(ConvKFCBasicFB, self).__init__(layer_collection)
 
   def instantiate_factors(self, grads_list, damping):
+    inputs, grads_list = self._package_minibatches(grads_list)
+
     # Infer number of locations upon which convolution is applied.
     self._num_locations = num_conv_locations(self._inputs[0].shape.as_list(),
                                              self._strides)
 
-    inputs, grads_list = self._package_minibatches(grads_list)
-
     self._input_factor = self._layer_collection.make_or_get_factor(
         fisher_factors.ConvInputKroneckerFactor,
-        (inputs, self._filter_shape, self._strides, self._padding,
+        (inputs, self._filter_shape, self._padding, self._strides,
+         self._dilation_rate, self._data_format, self._extract_patches_fn,
          self._has_bias))
     self._output_factor = self._layer_collection.make_or_get_factor(
         fisher_factors.ConvOutputKroneckerFactor, (grads_list,))
@@ -827,17 +879,262 @@ class ConvKFCBasicFB(InputOutputMultiMinibatch, KroneckerProductFB):
     return self._num_locations
 
 
+class DepthwiseConvDiagonalFB(ConvDiagonalFB):
+  """FisherBlock for depthwise_conv2d().
+
+  Equivalent to ConvDiagonalFB applied to each input channel in isolation.
+  """
+
+  def __init__(self,
+               layer_collection,
+               params,
+               strides,
+               padding,
+               rate=None,
+               data_format=None):
+    """Creates a DepthwiseConvKFCBasicFB block.
+
+    Args:
+      layer_collection: The collection of all layers in the K-FAC approximate
+          Fisher information matrix to which this FisherBlock belongs.
+      params: Tensor of shape [filter_height, filter_width, in_channels,
+        channel_multiplier].
+      strides: List of 4 ints. Strides along all dimensions.
+      padding: str. Padding method.
+      rate: List of 4 ints or None. Rate for dilation along all dimensions.
+      data_format: str or None. Format of input data.
+
+    Raises:
+      NotImplementedError: If parameters contains bias.
+      ValueError: If filter is not 4-D.
+      ValueError: If strides is not length-4.
+      ValueError: If rates is not length-2.
+      ValueError: If channels are not last dimension.
+    """
+    if isinstance(params, (tuple, list)):
+      raise NotImplementedError("Bias not yet supported.")
+
+    if params.shape.ndims != 4:
+      raise ValueError("Filter must be 4-D.")
+
+    if len(strides) != 4:
+      raise ValueError("strides must account for 4 dimensions.")
+
+    if rate is not None:
+      if len(rate) != 2:
+        raise ValueError("rate must only account for spatial dimensions.")
+      rate = [1, rate[0], rate[1], 1]  # conv2d expects 4-element rate.
+
+    if not utils.is_data_format_channel_last(data_format):
+      raise ValueError("data_format must be channels-last.")
+
+    super(DepthwiseConvDiagonalFB, self).__init__(
+        layer_collection=layer_collection,
+        params=params,
+        strides=strides,
+        padding=padding,
+        dilations=rate,
+        data_format=data_format)
+
+    # This is a hack to overwrite the same setting in ConvKFCBasicFB.__init__().
+    filter_height, filter_width, in_channels, channel_multiplier = (
+        params.shape.as_list())
+    self._filter_shape = (filter_height, filter_width, in_channels,
+                          in_channels * channel_multiplier)
+
+  def multiply_matpower(self, vector, exp):
+    conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector)
+    conv2d_result = super(DepthwiseConvDiagonalFB, self).multiply_matpower(
+        conv2d_vector, exp)
+    return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result)
+
+
+class DepthwiseConvKFCBasicFB(ConvKFCBasicFB):
+  """FisherBlock for depthwise_conv2d().
+
+  Equivalent to ConvKFCBasicFB applied to each input channel in isolation.
+  """
+
+  def __init__(self,
+               layer_collection,
+               params,
+               strides,
+               padding,
+               rate=None,
+               data_format=None):
+    """Creates a DepthwiseConvKFCBasicFB block.
+
+    Args:
+      layer_collection: The collection of all layers in the K-FAC approximate
+          Fisher information matrix to which this FisherBlock belongs.
+      params: Tensor of shape [filter_height, filter_width, in_channels,
+        channel_multiplier].
+      strides: List of 4 ints. Strides along all dimensions.
+      padding: str. Padding method.
+      rate: List of 4 ints or None. Rate for dilation along all dimensions.
+      data_format: str or None. Format of input data.
+
+    Raises:
+      NotImplementedError: If parameters contains bias.
+      ValueError: If filter is not 4-D.
+      ValueError: If strides is not length-4.
+      ValueError: If rates is not length-2.
+      ValueError: If channels are not last dimension.
+    """
+    if isinstance(params, (tuple, list)):
+      raise NotImplementedError("Bias not yet supported.")
+
+    if params.shape.ndims != 4:
+      raise ValueError("Filter must be 4-D.")
+
+    if len(strides) != 4:
+      raise ValueError("strides must account for 4 dimensions.")
+
+    if rate is not None:
+      if len(rate) != 2:
+        raise ValueError("rate must only account for spatial dimensions.")
+      rate = [1, rate[0], rate[1], 1]  # conv2d expects 4-element rate.
+
+    if not utils.is_data_format_channel_last(data_format):
+      raise ValueError("data_format must be channels-last.")
+
+    super(DepthwiseConvKFCBasicFB, self).__init__(
+        layer_collection=layer_collection,
+        params=params,
+        padding=padding,
+        strides=strides,
+        dilation_rate=rate,
+        data_format=data_format,
+        extract_patches_fn="extract_image_patches")
+
+    # This is a hack to overwrite the same setting in ConvKFCBasicFB.__init__().
+    filter_height, filter_width, in_channels, channel_multiplier = (
+        params.shape.as_list())
+    self._filter_shape = (filter_height, filter_width, in_channels,
+                          in_channels * channel_multiplier)
+
+  def multiply_matpower(self, vector, exp):
+    conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector)
+    conv2d_result = super(DepthwiseConvKFCBasicFB, self).multiply_matpower(
+        conv2d_vector, exp)
+    return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result)
+
+
+def depthwise_conv2d_filter_to_conv2d_filter(filter, name=None):  # pylint: disable=redefined-builtin
+  """Converts a convolution filter for use with conv2d.
+
+  Transforms a filter for use with tf.nn.depthwise_conv2d() to one that's
+  compatible with tf.nn.conv2d().
+
+  Args:
+    filter: Tensor of shape [height, width, in_channels, channel_multiplier].
+    name: None or str. Name of Op.
+
+  Returns:
+    Tensor of shape [height, width, in_channels, out_channels].
+
+  """
+  with ops.name_scope(name, "depthwise_conv2d_filter_to_conv2d_filter",
+                      [filter]):
+    filter = ops.convert_to_tensor(filter)
+    filter_height, filter_width, in_channels, channel_multiplier = (
+        filter.shape.as_list())
+
+    results = []
+    for i in range(in_channels):
+      # Slice out one in_channel's filter. Insert zeros around it to force it
+      # to affect that channel and that channel alone.
+      elements = []
+      if i > 0:
+        elements.append(
+            array_ops.zeros(
+                [filter_height, filter_width, i, channel_multiplier]))
+      elements.append(filter[:, :, i:(i + 1), :])
+      if i + 1 < in_channels:
+        elements.append(
+            array_ops.zeros([
+                filter_height, filter_width, in_channels - (i + 1),
+                channel_multiplier
+            ]))
+
+      # Concat along in_channel.
+      results.append(
+          array_ops.concat(elements, axis=-2, name="in_channel_%d" % i))
+
+    # Concat along out_channel.
+    return array_ops.concat(results, axis=-1, name="out_channel")
+
+
+def conv2d_filter_to_depthwise_conv2d_filter(filter, name=None):  # pylint: disable=redefined-builtin
+  """Converts a convolution filter for use with depthwise_conv2d.
+
+  Transforms a filter for use with tf.nn.conv2d() to one that's
+  compatible with tf.nn.depthwise_conv2d(). Ignores all filters but those along
+  the diagonal.
+
+  Args:
+    filter: Tensor of shape [height, width, in_channels, out_channels].
+    name: None or str. Name of Op.
+
+  Returns:
+    Tensor of shape,
+      [height, width, in_channels, channel_multiplier]
+
+  Raises:
+    ValueError: if out_channels is not evenly divisible by in_channels.
+  """
+  with ops.name_scope(name, "conv2d_filter_to_depthwise_conv2d_filter",
+                      [filter]):
+    filter = ops.convert_to_tensor(filter)
+    filter_height, filter_width, in_channels, out_channels = (
+        filter.shape.as_list())
+
+    if out_channels % in_channels != 0:
+      raise ValueError("out_channels must be evenly divisible by in_channels.")
+    channel_multiplier = out_channels // in_channels
+
+    results = []
+    filter = array_ops.reshape(filter, [
+        filter_height, filter_width, in_channels, in_channels,
+        channel_multiplier
+    ])
+    for i in range(in_channels):
+      # Slice out output corresponding to the correct filter.
+      filter_slice = array_ops.reshape(
+          filter[:, :, i, i, :],
+          [filter_height, filter_width, 1, channel_multiplier])
+      results.append(filter_slice)
+
+    # Concat along out_channel.
+    return array_ops.concat(results, axis=-2, name="in_channels")
+
+
+def maybe_tuple(obj):
+  if not isinstance(obj, list):
+    return obj
+  return tuple(obj)
+
+
 def num_conv_locations(input_shape, strides):
   """Returns the number of spatial locations a 2D Conv kernel is applied to.
 
   Args:
-    input_shape: list representing shape of inputs to the Conv layer.
-    strides: list representing strides for the Conv kernel.
+    input_shape: List of ints representing shape of inputs to
+      tf.nn.convolution().
+    strides: List of ints representing strides along spatial dimensions as
+      passed in to tf.nn.convolution().
 
   Returns:
     A scalar |T| denoting the number of spatial locations for the Conv layer.
   """
-  return input_shape[1] * input_shape[2] // (strides[1] * strides[2])
+  spatial_input_locations = np.prod(input_shape[1:-1])
+
+  if strides is None:
+    spatial_strides_divisor = 1
+  else:
+    spatial_strides_divisor = np.prod(strides)
+
+  return spatial_input_locations // spatial_strides_divisor
 
 
 class FullyConnectedMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
@@ -858,7 +1155,7 @@ class FullyConnectedMultiIndepFB(InputOutputMultiMinibatch, KroneckerProductFB):
 
   def instantiate_factors(self, grads_list, damping):
 
-    self._num_uses = len(self._inputs[0])
+    self._num_uses = float(len(self._inputs[0]))
     inputs, grads_list = self._package_minibatches_multi(grads_list)
 
     self._input_factor = self._layer_collection.make_or_get_factor(
index 8ac63bc..6fc163e 100644 (file)
@@ -159,7 +159,9 @@ def scope_string_from_params(params):
 
   name_parts = []
   for param in params:
-    if isinstance(param, (tuple, list)):
+    if param is None:
+      name_parts.append("None")
+    elif isinstance(param, (tuple, list)):
       if all([isinstance(p, int) for p in param]):
         name_parts.append("-".join([str(p) for p in param]))
       else:
@@ -867,6 +869,8 @@ class ConvDiagonalFactor(DiagonalFactor):
                filter_shape,
                strides,
                padding,
+               data_format=None,
+               dilations=None,
                has_bias=False):
     """Creates a ConvDiagonalFactor object.
 
@@ -880,15 +884,42 @@ class ConvDiagonalFactor(DiagonalFactor):
         out_channels). Represents shape of kernel used in this layer.
       strides: The stride size in this layer (1-D Tensor of length 4).
       padding: The padding in this layer (1-D of Tensor length 4).
+      data_format: None or str. Format of conv2d inputs.
+      dilations: None or tuple of 4 ints.
       has_bias: Python bool. If True, the layer is assumed to have a bias
         parameter in addition to its filter parameter.
+
+    Raises:
+      ValueError: If inputs, output_grads, and filter_shape do not agree on
+        in_channels or out_channels.
+      ValueError: If strides, dilations are not length-4 lists of ints.
+      ValueError: If data_format does not put channel last.
     """
+    if not utils.is_data_format_channel_last(data_format):
+      raise ValueError("Channel must be last.")
+    if inputs.shape.ndims != 4:
+      raise ValueError("inputs must be 4-D Tensor.")
+    if inputs.shape.as_list()[-1] != filter_shape[-2]:
+      raise ValueError("inputs and filter_shape must agree on in_channels.")
+    for i, outputs_grad in enumerate(outputs_grads):
+      if outputs_grad.shape.ndims != 4:
+        raise ValueError("outputs[%d] must be 4-D Tensor." % i)
+      if outputs_grad.shape.as_list()[-1] != filter_shape[-1]:
+        raise ValueError(
+            "outputs[%d] and filter_shape must agree on out_channels." % i)
+    if len(strides) != 4:
+      raise ValueError("strides must be length-4 list of ints.")
+    if dilations is not None and len(dilations) != 4:
+      raise ValueError("dilations must be length-4 list of ints.")
+
     self._inputs = inputs
+    self._outputs_grads = outputs_grads
     self._filter_shape = filter_shape
     self._strides = strides
     self._padding = padding
+    self._data_format = data_format
+    self._dilations = dilations
     self._has_bias = has_bias
-    self._outputs_grads = outputs_grads
     self._patches = None
 
     super(ConvDiagonalFactor, self).__init__()
@@ -919,11 +950,15 @@ class ConvDiagonalFactor(DiagonalFactor):
 
     # TODO(b/64144716): there is potential here for a big savings in terms
     # of memory use.
+    if self._dilations is None:
+      rates = (1, 1, 1, 1)
+    else:
+      rates = tuple(self._dilations)
     patches = array_ops.extract_image_patches(
         self._inputs,
         ksizes=[1, filter_height, filter_width, 1],
         strides=self._strides,
-        rates=[1, 1, 1, 1],
+        rates=rates,
         padding=self._padding)
 
     if self._has_bias:
@@ -1010,39 +1045,55 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
   def __init__(self,
                inputs,
                filter_shape,
-               strides,
                padding,
+               strides=None,
+               dilation_rate=None,
+               data_format=None,
+               extract_patches_fn=None,
                has_bias=False):
     """Initializes ConvInputKroneckerFactor.
 
     Args:
-      inputs: A Tensor of shape [batch_size, height, width, in_channels]
-        which is the inputs to the layer (before being processed into patches).
-      filter_shape: 1-D Tensor of length 4. Contains [kernel_height,
-        kernel_width, in_channels, out_channels].
-      strides: 1-D Tensor of length 4. Contains [batch_stride, height_stride,
-        width_stride, in_channel_stride].
+      inputs: Tensor of shape [batch_size, ..spatial_input_size.., in_channels].
+        Inputs to layer.
+      filter_shape: List of ints. Contains [..spatial_filter_size..,
+        in_channels, out_channels]. Shape of convolution kernel.
       padding: str. Padding method for layer. "SAME" or "VALID".
+      strides: List of ints or None. Contains [..spatial_filter_strides..] if
+        'extract_patches_fn' is compatible with tf.nn.convolution(), else
+        [1, ..spatial_filter_strides, 1].
+      dilation_rate: List of ints or None. Rate for dilation along each spatial
+        dimension if 'extract_patches_fn' is compatible with
+        tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1].
+      data_format: str or None. Format of input data.
+      extract_patches_fn: str or None. Name of function that extracts image
+        patches. One of "extract_convolution_patches", "extract_image_patches",
+        "extract_pointwise_conv2d_patches".
       has_bias: bool. If True, append 1 to in_channel.
     """
+    self._inputs = inputs
     self._filter_shape = filter_shape
     self._strides = strides
     self._padding = padding
+    self._dilation_rate = dilation_rate
+    self._data_format = data_format
+    self._extract_patches_fn = extract_patches_fn
     self._has_bias = has_bias
-    self._inputs = inputs
+
     super(ConvInputKroneckerFactor, self).__init__()
 
   @property
   def _var_scope(self):
     return "ff_convinkron_" + scope_string_from_params([
         self._inputs, self._filter_shape, self._strides, self._padding,
-        self._has_bias
+        self._dilation_rate, self._data_format, self._has_bias
     ])
 
   @property
   def _cov_shape(self):
-    filter_height, filter_width, in_channels, _ = self._filter_shape
-    size = filter_height * filter_width * in_channels + self._has_bias
+    spatial_filter_shape = self._filter_shape[0:-2]
+    in_channels = self._filter_shape[-2]
+    size = np.prod(spatial_filter_shape) * in_channels + self._has_bias
     return [size, size]
 
   @property
@@ -1057,18 +1108,44 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
     if idx != 0:
       raise ValueError("ConvInputKroneckerFactor only supports idx = 0")
 
-    filter_height, filter_width, in_channels, _ = self._filter_shape
-
     # TODO(b/64144716): there is potential here for a big savings in terms of
     # memory use.
-    patches = array_ops.extract_image_patches(
-        self._inputs,
-        ksizes=[1, filter_height, filter_width, 1],
-        strides=self._strides,
-        rates=[1, 1, 1, 1],
-        padding=self._padding)
+    if self._extract_patches_fn in [None, "extract_convolution_patches"]:
+      patches = utils.extract_convolution_patches(
+          self._inputs,
+          self._filter_shape,
+          padding=self._padding,
+          strides=self._strides,
+          dilation_rate=self._dilation_rate,
+          data_format=self._data_format)
+
+    elif self._extract_patches_fn == "extract_image_patches":
+      assert self._inputs.shape.ndims == 4
+      assert len(self._filter_shape) == 4
+      assert len(self._strides) == 4, self._strides
+      if self._dilation_rate is None:
+        rates = [1, 1, 1, 1]
+      else:
+        rates = self._dilation_rate
+        assert len(rates) == 4
+        assert rates[0] == rates[-1] == 1
+      patches = array_ops.extract_image_patches(
+          self._inputs,
+          ksizes=[1] + list(self._filter_shape[0:-2]) + [1],
+          strides=self._strides,
+          rates=rates,
+          padding=self._padding)
+
+    elif self._extract_patches_fn == "extract_pointwise_conv2d_patches":
+      assert self._strides in [None, [1, 1, 1, 1], (1, 1, 1, 1)]
+      assert self._filter_shape[0] == self._filter_shape[1] == 1
+      patches = utils.extract_pointwise_conv2d_patches(
+          self._inputs, self._filter_shape, data_format=None)
 
-    flatten_size = (filter_height * filter_width * in_channels)
+    else:
+      raise NotImplementedError(self._extract_patches_fn)
+
+    flatten_size = np.prod(self._filter_shape[0:-1])
     # patches_flat below is the matrix [[A_l]] from the KFC paper (tilde
     # omitted over A for clarity). It has shape M|T| x J|Delta| (eq. 14),
     # where M = minibatch size, |T| = number of spatial locations,
@@ -1100,14 +1177,21 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor):
   Section 3.1 Estimating the factors.
   """
 
-  def __init__(self, outputs_grads):
+  def __init__(self, outputs_grads, data_format=None):
     """Initializes ConvOutputKroneckerFactor.
 
     Args:
-      outputs_grads: List of Tensors, each of shape [batch_size,
-        height, width, out_channels].  One Tensor for each "source".
+      outputs_grads: list of Tensors. Each Tensor is of shape
+          [batch_size, ..spatial_input_size.., out_channels]. One Tensor per
+          source.
+      data_format: None or str. Format of outputs_grads.
+
+    Raises:
+      ValueError: If channels are not final dimension.
     """
-    self._out_channels = outputs_grads[0].shape.as_list()[3]
+    if not utils.is_data_format_channel_last(data_format):
+      raise ValueError("Channel must be last.")
+    self._out_channels = outputs_grads[0].shape.as_list()[-1]
     self._outputs_grads = outputs_grads
     super(ConvOutputKroneckerFactor, self).__init__()
 
@@ -1433,4 +1517,3 @@ class FullyConnectedMultiKF(InverseProvidingFactor):
     return [control_flow_ops.group(*ops)]
 
     # pylint: enable=invalid-name
-
index 60894ed..4eb5e4c 100644 (file)
@@ -26,6 +26,7 @@ from __future__ import print_function
 
 from collections import defaultdict
 from collections import OrderedDict
+from contextlib import contextmanager
 from functools import partial
 
 import math
@@ -75,6 +76,27 @@ _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES = {
 # tf.get_variable_scope().reuse.
 VARIABLE_SCOPE = "VARIABLE_SCOPE"
 
+_DEFAULT_LAYER_COLLECTION = None
+
+
+def get_default_layer_collection():
+  """Get default LayerCollection."""
+  if _DEFAULT_LAYER_COLLECTION is None:
+    raise ValueError(
+        "Attempted to retrieve default LayerCollection when none is set. Use "
+        "LayerCollection.as_default().")
+
+  return _DEFAULT_LAYER_COLLECTION
+
+
+def set_default_layer_collection(layer_collection):
+  global _DEFAULT_LAYER_COLLECTION
+
+  if _DEFAULT_LAYER_COLLECTION is not None and layer_collection is not None:
+    raise ValueError("Default LayerCollection is already set.")
+
+  _DEFAULT_LAYER_COLLECTION = layer_collection
+
 
 class LayerParametersDict(OrderedDict):
   """An OrderedDict where keys are Tensors or tuples of Tensors.
@@ -594,21 +616,25 @@ class LayerCollection(object):
                       padding,
                       inputs,
                       outputs,
+                      data_format=None,
+                      dilations=None,
                       approx=None,
                       reuse=VARIABLE_SCOPE):
-    """Registers a convolutional layer.
+    """Registers a call to tf.nn.conv2d().
 
     Args:
       params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
         this layer. Weight matrix should have shape [kernel_height,
         kernel_width, in_channels, out_channels].  Bias should have shape
         [out_channels].
-      strides: 1-D Tensor of length 4. Strides for convolution kernel.
+      strides: List of 4 ints. Strides for convolution kernel.
       padding: string. see tf.nn.conv2d for valid values.
       inputs: Tensor of shape [batch_size, height, width, in_channels]. Inputs
         to layer.
       outputs: Tensor of shape [batch_size, height, width, out_channels].
         Output produced by layer.
+      data_format: str or None. Format of data.
+      dilations: List of 4 ints. Dilations along each dimension.
       approx: str. One of "kron" or "diagonal".
       reuse: bool or str.  If True, reuse an existing FisherBlock. If False,
         create a new FisherBlock.  If "VARIABLE_SCOPE", use
@@ -629,12 +655,206 @@ class LayerCollection(object):
       raise ValueError("Bad value {} for approx.".format(approx))
 
     block_type = _CONV2D_APPROX_TO_BLOCK_TYPES[approx]
+    if approx == APPROX_KRONECKER_NAME:
+      block = self.register_block(
+          params,
+          block_type(
+              layer_collection=self,
+              params=params,
+              padding=padding,
+              strides=strides,
+              data_format=data_format,
+              dilation_rate=dilations,
+              extract_patches_fn="extract_image_patches"),
+          reuse=reuse)
+    elif approx == APPROX_DIAGONAL_NAME:
+      assert strides[0] == strides[-1] == 1
+      block = self.register_block(
+          params,
+          block_type(
+              layer_collection=self,
+              params=params,
+              padding=padding,
+              strides=strides,
+              dilations=dilations,
+              data_format=data_format),
+          reuse=reuse)
+    else:
+      raise NotImplementedError
+
+    block.register_additional_minibatch(inputs, outputs)
+
+    self._add_uses(params, 1)
+
+  def register_convolution(self,
+                           params,
+                           inputs,
+                           outputs,
+                           padding,
+                           strides=None,
+                           dilation_rate=None,
+                           data_format=None,
+                           approx=None,
+                           reuse=VARIABLE_SCOPE):
+    """Register a call to tf.nn.convolution().
+
+    Args:
+      params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
+        this layer. Weight matrix should have shape [..filter_spatial_size..,
+        in_channels, out_channels].  Bias should have shape [out_channels].
+      inputs: Tensor of shape [batch_size, ..input_spatial_size.., in_channels].
+        Inputs to layer.
+      outputs: Tensor of shape [batch_size, ..output_spatial_size..,
+        out_channels].  Output produced by layer.
+      padding: string. see tf.nn.conv2d for valid values.
+      strides: List of ints of length len(..input_spatial_size..). Strides for
+        convolution kernel in spatial dimensions.
+      dilation_rate: List of ints of length len(..input_spatial_size..).
+        Dilations along spatial dimension.
+      data_format: str or None. Format of data.
+      approx: str. One of "kron" or "diagonal".
+      reuse: bool or str.  If True, reuse an existing FisherBlock. If False,
+        create a new FisherBlock.  If "VARIABLE_SCOPE", use
+        tf.get_variable_scope().reuse.
+
+    Raises:
+      ValueError: For improper value to 'approx'.
+      KeyError: If reuse == True but no FisherBlock found for 'params'.
+      ValueError: If reuse == True and FisherBlock found but of the wrong type.
+    """
+    assert approx is None or approx == APPROX_KRONECKER_NAME
+
     block = self.register_block(
-        params, block_type(self, params, strides, padding), reuse=reuse)
+        params,
+        fb.ConvKFCBasicFB(
+            layer_collection=self,
+            params=params,
+            padding=padding,
+            strides=strides,
+            dilation_rate=dilation_rate,
+            data_format=data_format),
+        reuse=reuse)
     block.register_additional_minibatch(inputs, outputs)
 
     self._add_uses(params, 1)
 
+  def register_depthwise_conv2d(self,
+                                params,
+                                inputs,
+                                outputs,
+                                strides,
+                                padding,
+                                rate=None,
+                                data_format=None,
+                                approx=None,
+                                reuse=VARIABLE_SCOPE):
+    """Register a call to tf.nn.depthwise_conv2d().
+
+    Args:
+      params: 4-D Tensor of shape [filter_height, filter_width,
+        in_channels, channel_multiplier].  Convolutional filter.
+      inputs: Tensor of shape [batch_size, input_height, input_width,
+        in_channels].  Inputs to layer.
+      outputs: Tensor of shape [batch_size, output_height, output_width,
+        in_channels * channel_multiplier].  Output produced by depthwise conv2d.
+      strides: List of ints of length 4. Strides along all dimensions.
+      padding: string. see tf.nn.conv2d for valid values.
+      rate: None or List of ints of length 2. Dilation rates in spatial
+        dimensions.
+      data_format: str or None. Format of data.
+      approx: None or str. Must be "diagonal" if non-None.
+      reuse: bool or str.  If True, reuse an existing FisherBlock. If False,
+        create a new FisherBlock.  If "VARIABLE_SCOPE", use
+        tf.get_variable_scope().reuse.
+
+    Raises:
+      ValueError: For improper value to 'approx'.
+      KeyError: If reuse == True but no FisherBlock found for 'params'.
+      ValueError: If reuse == True and FisherBlock found but of the wrong type.
+    """
+    assert approx is None or approx == APPROX_DIAGONAL_NAME
+    assert data_format in [None, "NHWC"]
+
+    block = self.register_block(
+        params,
+        fb.DepthwiseConvDiagonalFB(
+            layer_collection=self,
+            params=params,
+            strides=strides,
+            padding=padding,
+            rate=rate,
+            data_format=data_format),
+        reuse=reuse)
+    block.register_additional_minibatch(inputs, outputs)
+
+    self._add_uses(params, 1)
+
+  def register_separable_conv2d(self,
+                                depthwise_params,
+                                pointwise_params,
+                                inputs,
+                                depthwise_outputs,
+                                pointwise_outputs,
+                                strides,
+                                padding,
+                                rate=None,
+                                data_format=None,
+                                approx=None,
+                                reuse=VARIABLE_SCOPE):
+    """Register a call to tf.nn.separable_conv2d().
+
+    Note: This requires access to intermediate outputs betwee depthwise and
+    pointwise convolutions.
+
+    Args:
+      depthwise_params: 4-D Tensor of shape [filter_height, filter_width,
+        in_channels, channel_multiplier].  Filter for depthwise conv2d.
+      pointwise_params: 4-D Tensor of shape [1, 1, in_channels *
+        channel_multiplier, out_channels].  Filter for pointwise conv2d.
+      inputs: Tensor of shape [batch_size, input_height, input_width,
+        in_channels].  Inputs to layer.
+      depthwise_outputs: Tensor of shape [batch_size, output_height,
+        output_width, in_channels * channel_multiplier].  Output produced by
+        depthwise conv2d.
+      pointwise_outputs: Tensor of shape [batch_size, output_height,
+        output_width, out_channels].  Output produced by pointwise conv2d.
+      strides: List of ints of length 4. Strides for depthwise conv2d kernel in
+        all dimensions.
+      padding: string. see tf.nn.conv2d for valid values.
+      rate: None or List of ints of length 2. Dilation rate of depthwise conv2d
+        kernel in spatial dimensions.
+      data_format: str or None. Format of data.
+      approx: None or str. Must be "kron" if non-None.
+      reuse: bool or str.  If True, reuse an existing FisherBlock. If False,
+        create a new FisherBlock.  If "VARIABLE_SCOPE", use
+        tf.get_variable_scope().reuse.
+
+    Raises:
+      ValueError: For improper value to 'approx'.
+      KeyError: If reuse == True but no FisherBlock found for 'params'.
+      ValueError: If reuse == True and FisherBlock found but of the wrong type.
+    """
+    self.register_depthwise_conv2d(
+        params=depthwise_params,
+        inputs=inputs,
+        outputs=depthwise_outputs,
+        strides=strides,
+        padding=padding,
+        rate=rate,
+        data_format=data_format,
+        approx=APPROX_DIAGONAL_NAME,
+        reuse=reuse)
+
+    self.register_conv2d(
+        params=pointwise_params,
+        inputs=depthwise_outputs,
+        outputs=pointwise_outputs,
+        strides=[1, 1, 1, 1],
+        padding="VALID",
+        data_format=data_format,
+        approx=approx,
+        reuse=reuse)
+
   def register_generic(self,
                        params,
                        batch_size,
@@ -833,3 +1053,10 @@ class LayerCollection(object):
       with variable_scope.variable_scope(self._var_scope):
         self.fisher_factors[key] = cls(*args)
     return self.fisher_factors[key]
+
+  @contextmanager
+  def as_default(self):
+    """Sets this LayerCollection as the default."""
+    set_default_layer_collection(self)
+    yield
+    set_default_layer_collection(None)
index f8aa230..9f46853 100644 (file)
@@ -30,6 +30,8 @@ from tensorflow.python.util.all_util import remove_undocumented
 # pylint: enable=unused-import,line-too-long,wildcard-import
 
 _allowed_symbols = [
+    "get_default_layer_collection",
+    "set_default_layer_collection",
     "LayerParametersDict",
     "LayerCollection",
     "APPROX_KRONECKER_NAME",
index 5ce5338..af26f5e 100644 (file)
@@ -30,6 +30,7 @@ from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import linalg_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variables
@@ -431,6 +432,127 @@ def batch_execute(global_step, thunks, batch_size, name=None):
     return result
 
 
+def extract_convolution_patches(inputs,
+                                filter_shape,
+                                padding,
+                                strides=None,
+                                dilation_rate=None,
+                                name=None,
+                                data_format=None):
+  """Extracts inputs to each output coordinate in tf.nn.convolution.
+
+  This is a generalization of tf.extract_image_patches() to tf.nn.convolution(),
+  where the number of spatial dimensions may be something other than 2.
+
+  Assumes,
+  - First dimension of inputs is batch_size
+  - Convolution filter is applied to all input channels.
+
+  Args:
+    inputs: Tensor of shape [batch_size, ..spatial_image_shape..,
+      ..spatial_filter_shape.., in_channels]. Inputs to tf.nn.convolution().
+    filter_shape: List of ints. Shape of filter passed to tf.nn.convolution().
+    padding: string. Padding method. One of "VALID", "SAME".
+    strides: None or list of ints. Strides along spatial dimensions.
+    dilation_rate: None or list of ints. Dilation along spatial dimensions.
+    name: None or str. Name of Op.
+    data_format: None or str. Format of data.
+
+  Returns:
+    Tensor of shape [batch_size, ..spatial_image_shape..,
+      ..spatial_filter_shape.., in_channels]
+
+  Raises:
+    ValueError: If data_format does not put channel last.
+    ValueError: If inputs and filter disagree on in_channels.
+  """
+  if not is_data_format_channel_last(data_format):
+    raise ValueError("Channel must be last dimension.")
+  with ops.name_scope(name, "extract_convolution_patches",
+                      [inputs, filter_shape, padding, strides, dilation_rate]):
+    batch_size = inputs.shape.as_list()[0]
+    in_channels = inputs.shape.as_list()[-1]
+
+    # filter_shape = spatial_filter_shape + [in_channels, out_channels]
+    spatial_filter_shape = filter_shape[:-2]
+    if in_channels != filter_shape[-2]:
+      raise ValueError("inputs and filter_shape must agree on in_channels.")
+
+    # Map each input feature to a location in the output.
+    out_channels = np.prod(spatial_filter_shape) * in_channels
+    filters = linalg_ops.eye(out_channels)
+    filters = array_ops.reshape(
+        filters,
+        list(spatial_filter_shape) + [in_channels, out_channels])
+
+    result = nn_ops.convolution(
+        inputs,
+        filters,
+        padding=padding,
+        strides=strides,
+        dilation_rate=dilation_rate)
+    spatial_output_shape = result.shape.as_list()[1:-1]
+    result = array_ops.reshape(result,
+                               [batch_size or -1] + spatial_output_shape +
+                               list(spatial_filter_shape) + [in_channels])
+
+    return result
+
+
+def extract_pointwise_conv2d_patches(inputs,
+                                     filter_shape,
+                                     name=None,
+                                     data_format=None):
+  """Extract patches for a 1x1 conv2d.
+
+  Args:
+    inputs: 4-D Tensor of shape [batch_size, height, width, in_channels].
+    filter_shape: List of 4 ints. Shape of filter to apply with conv2d()
+    name: None or str. Name for Op.
+    data_format: None or str. Format for data. See 'data_format' in
+      tf.nn.conv2d() for details.
+
+  Returns:
+    Tensor of shape [batch_size, ..spatial_input_shape..,
+    ..spatial_filter_shape.., in_channels]
+
+  Raises:
+    ValueError: if inputs is not 4-D.
+    ValueError: if filter_shape is not [1, 1, ?, ?]
+    ValueError: if data_format is not channels-last.
+  """
+  if inputs.shape.ndims != 4:
+    raise ValueError("inputs must have 4 dims.")
+  if len(filter_shape) != 4:
+    raise ValueError("filter_shape must have 4 dims.")
+  if filter_shape[0] != 1 or filter_shape[1] != 1:
+    raise ValueError("filter_shape must have shape 1 along spatial dimensions.")
+  if not is_data_format_channel_last(data_format):
+    raise ValueError("data_format must be channels last.")
+  with ops.name_scope(name, "extract_pointwise_conv2d_patches",
+                      [inputs, filter_shape]):
+    ksizes = [1, 1, 1, 1]  # Spatial shape is 1x1.
+    strides = [1, 1, 1, 1]  # Operate on all pixels.
+    rates = [1, 1, 1, 1]  # Dilation has no meaning with spatial shape = 1.
+    padding = "VALID"  # Doesn't matter.
+    result = array_ops.extract_image_patches(inputs, ksizes, strides, rates,
+                                             padding)
+
+    batch_size, input_height, input_width, in_channels = inputs.shape.as_list()
+    filter_height, filter_width, in_channels, _ = filter_shape
+    return array_ops.reshape(result, [
+        batch_size, input_height, input_width, filter_height, filter_width,
+        in_channels
+    ])
+
+
+def is_data_format_channel_last(data_format):
+  """True if data_format puts channel last."""
+  if data_format is None:
+    return True
+  return data_format.endswith("C")
+
+
 def matmul_sparse_dense(A, B, name=None):  # pylint: disable=invalid-name
   """Computes matmul(A, B) where A is sparse, B is dense.
 
index 8e424a7..330d222 100644 (file)
@@ -40,6 +40,9 @@ _allowed_symbols = [
     "fwd_gradients",
     "ensure_sequence",
     "batch_execute",
+    "extract_convolution_patches",
+    "extract_pointwise_conv2d_patches",
+    "is_data_format_channel_last",
     "matmul_sparse_dense",
     "matmul_diag_sparse",
 ]