"ops/parsing_ops_test.cc",
"ops/random_ops_test.cc",
"ops/set_ops_test.cc",
+ "ops/shape_function_test.cc",
"ops/sparse_ops_test.cc",
"ops/spectral_ops_test.cc",
"ops/state_ops_test.cc",
#include "tensorflow/core/common_runtime/function_testlib.h"
#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
REGISTER_KERNEL_BUILDER(Name("FindDeviceOp").Device(tensorflow::DEVICE_CPU),
FindDeviceOpKernel);
-REGISTER_OP("FindDeviceOp").Output("device_name: string");
+REGISTER_OP("FindDeviceOp")
+ .Output("device_name: string")
+ .SetShapeFn(shape_inference::UnknownShape);
FunctionDef FindDevice() {
return FDH::Define(
}
}
+void OpRegistry::GetOpRegistrationData(
+ std::vector<OpRegistrationData>* op_data) {
+ mutex_lock lock(mu_);
+ MustCallDeferred();
+ for (const auto& p : registry_) {
+ op_data->push_back(*p.second);
+ }
+}
+
Status OpRegistry::SetWatcher(const Watcher& watcher) {
mutex_lock lock(mu_);
if (watcher_ && watcher) {
// Get all registered ops.
void GetRegisteredOps(std::vector<OpDef>* op_defs);
+ // Get all `OpRegistrationData`s.
+ void GetOpRegistrationData(std::vector<OpRegistrationData>* op_data);
+
// Watcher, a function object.
// The watcher, if set by SetWatcher(), is called every time an op is
// registered via the Register function. The watcher is passed the Status
#include "tensorflow/core/graph/testlib.h"
#include <vector>
+#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
REGISTER_OP("HostConst")
.Output("output: dtype")
.Attr("value: tensor")
- .Attr("dtype: type");
+ .Attr("dtype: type")
+ .SetShapeFn(shape_inference::UnknownShape);
namespace test {
namespace graph {
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
.Attr("Tin: list(type)")
.Attr("T: type")
.Attr("N: int >= 1")
+ .SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Converts a list of tensors to an array of tensors.
)doc");
.Attr("T: type")
.Attr("N: int >= 1")
.Attr("out_types: list(type)")
+ .SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
Converts an array of tensors to a list of tensors.
)doc");
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/platform/test.h"
+
+// Test to ensure that all core ops have shape functions defined. This is done
+// by looking at all ops registered in the test binary.
+
+namespace tensorflow {
+
+TEST(ShapeFunctionTest, RegisteredOpsHaveShapeFns) {
+ OpRegistry* op_registry = OpRegistry::Global();
+ std::vector<OpRegistrationData> op_data;
+ op_registry->GetOpRegistrationData(&op_data);
+ for (const OpRegistrationData& op_reg_data : op_data) {
+ EXPECT_TRUE(op_reg_data.shape_inference_fn != nullptr)
+ << op_reg_data.op_def.name();
+ }
+}
+
+} // namespace tensorflow
REGISTER_OP("BatchFFT")
.Input("input: complex64")
.Output("output: complex64")
+ .SetShapeFn(shape_inference::UnknownShape)
.Deprecated(15, "Use FFT");
REGISTER_OP("BatchIFFT")
.Input("input: complex64")
.Output("output: complex64")
+ .SetShapeFn(shape_inference::UnknownShape)
.Deprecated(15, "Use IFFT");
REGISTER_OP("BatchFFT2D")
.Input("input: complex64")
.Output("output: complex64")
+ .SetShapeFn(shape_inference::UnknownShape)
.Deprecated(15, "Use FFT2D");
REGISTER_OP("BatchIFFT2D")
.Input("input: complex64")
.Output("output: complex64")
+ .SetShapeFn(shape_inference::UnknownShape)
.Deprecated(15, "Use IFFT2D");
REGISTER_OP("BatchFFT3D")
.Input("input: complex64")
.Output("output: complex64")
+ .SetShapeFn(shape_inference::UnknownShape)
.Deprecated(15, "Use FFT3D");
REGISTER_OP("BatchIFFT3D")
.Input("input: complex64")
.Output("output: complex64")
+ .SetShapeFn(shape_inference::UnknownShape)
.Deprecated(15, "Use IFFT3D");
} // namespace tensorflow
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
namespace tensorflow {
.Attr("batch_size: int")
.Attr("window_size: int = 5")
.Attr("min_count: int = 5")
- .Attr("subsample: float = 1e-3");
+ .Attr("subsample: float = 1e-3")
+ .SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("NegTrain")
.Deprecated(19,
.Input("lr: float")
.SetIsStateful()
.Attr("vocab_count: list(int)")
- .Attr("num_negative_samples: int");
+ .Attr("num_negative_samples: int")
+ .SetShapeFn(shape_inference::UnknownShape);
} // end namespace tensorflow
// An example Op.
+#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
-REGISTER_OP("Fact").Output("fact: string");
+REGISTER_OP("Fact")
+ .Output("fact: string")
+ .SetShapeFn(tensorflow::shape_inference::UnknownShape);
class FactOp : public tensorflow::OpKernel {
public: