move eye to linalg_ops_impl
authorWenhao Hu <fumihwh@gmail.com>
Sun, 15 Apr 2018 01:26:01 +0000 (10:26 +0900)
committerWenhao Hu <fumihwh@gmail.com>
Sun, 15 Apr 2018 01:26:01 +0000 (10:26 +0900)
tensorflow/python/BUILD
tensorflow/python/ops/init_ops.py
tensorflow/python/ops/linalg/linalg.py
tensorflow/python/ops/linalg/linalg_impl.py
tensorflow/python/ops/linalg_ops.py
tensorflow/python/ops/linalg_ops_impl.py [new file with mode: 0644]

index 0cd3f27..1225786 100644 (file)
@@ -1934,7 +1934,8 @@ py_library(
         ":array_ops",
         ":constant_op",
         ":dtypes",
-        ":linalg_ops",
+        ":linalg_ops_gen",
+        ":linalg_ops_impl",
         ":math_ops",
         ":nn_ops",
         ":random_ops",
@@ -1971,7 +1972,6 @@ py_library(
         ":array_ops",
         ":control_flow_ops",
         ":framework_for_generated_wrappers",
-        ":functional_ops",
         ":linalg_ops",
         ":math_ops",
         "//tensorflow/python/ops/linalg:linalg_impl",
@@ -1986,7 +1986,22 @@ py_library(
         ":array_ops",
         ":dtypes",
         ":framework_ops",
+        ":functional_ops",
         ":linalg_ops_gen",
+        ":linalg_ops_impl",
+        ":math_ops",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_library(
+    name = "linalg_ops_impl",
+    srcs = ["ops/linalg_ops_impl.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":array_ops",
+        ":dtypes",
+        ":framework_ops",
         ":math_ops",
         "//third_party/py/numpy",
     ],
index 9dfe5ff..366a72c 100644 (file)
@@ -39,7 +39,8 @@ import numpy as np
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import linalg_ops_impl
+from tensorflow.python.ops import gen_linalg_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import random_ops
@@ -529,7 +530,7 @@ class Orthogonal(Initializer):
     # Generate a random matrix
     a = random_ops.random_normal(flat_shape, dtype=dtype, seed=self.seed)
     # Compute the qr factorization
-    q, r = linalg_ops.qr(a, full_matrices=False)
+    q, r = gen_linalg_ops.qr(a, full_matrices=False)
     # Make Q uniform
     d = array_ops.diag_part(r)
     q *= math_ops.sign(d)
@@ -578,7 +579,7 @@ class ConvolutionDeltaOrthogonal(Initializer):
     a = random_ops.random_normal([shape[-1], shape[-1]],
                                  dtype=dtype, seed=self.seed)
     # Compute the qr factorization
-    q, r = linalg_ops.qr(a, full_matrices=False)
+    q, r = gen_linalg_ops.qr(a, full_matrices=False)
     # Make Q uniform
     d = array_ops.diag_part(r)
     # ph = d / math_ops.abs(d)
@@ -623,7 +624,7 @@ class Identity(Initializer):
           "Identity matrix initializer can only be used for 2D matrices.")
     if dtype is None:
       dtype = self.dtype
-    initializer = linalg_ops.eye(*full_shape, dtype=dtype)
+    initializer = linalg_ops_impl.eye(*full_shape, dtype=dtype)
     if partition_info is not None:
       initializer = array_ops.slice(initializer, partition_info.var_offset,
                                     shape)
index 1431902..7e9c3cd 100644 (file)
@@ -39,6 +39,7 @@ del ops
 del array_ops
 del gen_linalg_ops
 del linalg_ops
+del linalg_ops_impl
 del math_ops
 del special_math_ops
 del tf_export
index 8343c62..6b1a046 100644 (file)
@@ -22,6 +22,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_linalg_ops
 from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import linalg_ops_impl
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import special_math_ops
 from tensorflow.python.util.tf_export import tf_export
@@ -40,7 +41,7 @@ eigvalsh = linalg_ops.self_adjoint_eigvals
 einsum = special_math_ops.einsum
 expm = gen_linalg_ops.matrix_exponential
 tf_export('linalg.expm')(expm)
-eye = linalg_ops.eye
+eye = linalg_ops_impl.eye
 inv = linalg_ops.matrix_inverse
 logm = gen_linalg_ops.matrix_logarithm
 tf_export('linalg.logm')(logm)
index 50706e5..805fbd9 100644 (file)
@@ -26,6 +26,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import functional_ops
 from tensorflow.python.ops import gen_linalg_ops
+from tensorflow.python.ops import linalg_ops_impl
 from tensorflow.python.ops import math_ops
 # pylint: disable=wildcard-import
 from tensorflow.python.ops.gen_linalg_ops import *
@@ -160,36 +161,11 @@ def eye(num_rows,
   Returns:
     A `Tensor` of shape `batch_shape + [num_rows, num_columns]`
   """
-  with ops.name_scope(
-      name, default_name='eye', values=[num_rows, num_columns, batch_shape]):
-    is_square = num_columns is None
-    batch_shape = [] if batch_shape is None else batch_shape
-    num_columns = num_rows if num_columns is None else num_columns
-    if isinstance(num_rows, ops.Tensor) or isinstance(
-        num_columns, ops.Tensor) or isinstance(batch_shape, ops.Tensor):
-      batch_shape = ops.convert_to_tensor(
-          batch_shape, name='shape', dtype=dtypes.int32)
-      diag_size = math_ops.minimum(num_rows, num_columns)
-      diag_shape = array_ops.concat((batch_shape, [diag_size]), 0)
-      if not is_square:
-        shape = array_ops.concat((batch_shape, [num_rows, num_columns]), 0)
-    else:
-      if not isinstance(num_rows, compat.integral_types) or not isinstance(
-          num_columns, compat.integral_types):
-        raise TypeError(
-            'num_rows and num_columns must be positive integer values.')
-      batch_shape = [dim for dim in batch_shape]
-      is_square = num_rows == num_columns
-      diag_shape = batch_shape + [np.minimum(num_rows, num_columns)]
-      if not is_square:
-        shape = batch_shape + [num_rows, num_columns]
-
-    diag_ones = array_ops.ones(diag_shape, dtype=dtype)
-    if is_square:
-      return array_ops.matrix_diag(diag_ones)
-    else:
-      zero_matrix = array_ops.zeros(shape, dtype=dtype)
-      return array_ops.matrix_set_diag(zero_matrix, diag_ones)
+  return linalg_ops_impl.eye(num_rows,
+                             num_columns=num_columns,
+                             batch_shape=batch_shape,
+                             dtype=dtype,
+                             name=name)
 
 
 @tf_export('matrix_solve_ls', 'linalg.lstsq')
diff --git a/tensorflow/python/ops/linalg_ops_impl.py b/tensorflow/python/ops/linalg_ops_impl.py
new file mode 100644 (file)
index 0000000..9263b95
--- /dev/null
@@ -0,0 +1,73 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Operations for linear algebra."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.util import compat
+
+# Names below are lower_case.
+# pylint: disable=invalid-name
+
+
+def eye(num_rows,
+        num_columns=None,
+        batch_shape=None,
+        dtype=dtypes.float32,
+        name=None):
+  """Construct an identity matrix, or a batch of matrices.
+
+  See `linalg_ops.eye`.
+  """
+  with ops.name_scope(
+      name, default_name='eye', values=[num_rows, num_columns, batch_shape]):
+    is_square = num_columns is None
+    batch_shape = [] if batch_shape is None else batch_shape
+    num_columns = num_rows if num_columns is None else num_columns
+    if isinstance(num_rows, ops.Tensor) or isinstance(
+        num_columns, ops.Tensor) or isinstance(batch_shape, ops.Tensor):
+      batch_shape = ops.convert_to_tensor(
+          batch_shape, name='shape', dtype=dtypes.int32)
+      diag_size = math_ops.minimum(num_rows, num_columns)
+      diag_shape = array_ops.concat((batch_shape, [diag_size]), 0)
+      if not is_square:
+        shape = array_ops.concat((batch_shape, [num_rows, num_columns]), 0)
+    else:
+      if not isinstance(num_rows, compat.integral_types) or not isinstance(
+          num_columns, compat.integral_types):
+        raise TypeError(
+            'num_rows and num_columns must be positive integer values.')
+      batch_shape = [dim for dim in batch_shape]
+      is_square = num_rows == num_columns
+      diag_shape = batch_shape + [np.minimum(num_rows, num_columns)]
+      if not is_square:
+        shape = batch_shape + [num_rows, num_columns]
+
+    diag_ones = array_ops.ones(diag_shape, dtype=dtype)
+    if is_square:
+      return array_ops.matrix_diag(diag_ones)
+    else:
+      zero_matrix = array_ops.zeros(shape, dtype=dtype)
+      return array_ops.matrix_set_diag(zero_matrix, diag_ones)
+
+# pylint: enable=invalid-name,redefined-builtin