namespace caffe2 {
namespace testing {
+// Asserts that two float values are close within epsilon.
+void assertNear(float value1, float value2, float epsilon) {
+ // These two enforces will give good debug messages.
+ CAFFE_ENFORCE_LE(value1, value2 + epsilon);
+ CAFFE_ENFORCE_GE(value1, value2 - epsilon);
+}
+
void assertTensorEquals(const TensorCPU& tensor1, const TensorCPU& tensor2) {
CAFFE_ENFORCE_EQ(tensor1.sizes(), tensor2.sizes());
if (tensor1.IsType<float>()) {
const caffe2::Tensor& getTensor(
const caffe2::Workspace& workspace,
const std::string& name) {
+ CAFFE_ENFORCE(workspace.HasBlob(name));
return workspace.GetBlob(name)->Get<caffe2::Tensor>();
}
#include "caffe2/core/tensor.h"
#include "caffe2/core/workspace.h"
+#include <cmath>
+#include <vector>
+
// Utilities that make it easier to write caffe2 C++ unit tests.
// These utils are designed to be concise and easy to use. They may sacrifice
// performance and should only be used in tests/non production code.
// Asserts that the values of two tensors are the same.
void assertTensorEquals(const TensorCPU& tensor1, const TensorCPU& tensor2);
+// Asserts that two float values are close within epsilon.
+void assertNear(float value1, float value2, float epsilon);
+
+// Asserts that the numeric values of a tensor is equal to a data vector.
+template <typename T>
+void assertTensorEquals(
+ const TensorCPU& tensor,
+ const std::vector<T>& data,
+ float epsilon = 0.1f) {
+ CAFFE_ENFORCE(tensor.IsType<T>());
+ CAFFE_ENFORCE_EQ(tensor.numel(), data.size());
+ for (auto idx = 0; idx < tensor.numel(); ++idx) {
+ if (tensor.IsType<float>()) {
+ assertNear(tensor.data<T>()[idx], data[idx], epsilon);
+ } else {
+ CAFFE_ENFORCE_EQ(tensor.data<T>()[idx], data[idx]);
+ }
+ }
+}
+
+// Assertion for tensor sizes and values.
+template <typename T>
+void assertTensor(
+ const TensorCPU& tensor,
+ const std::vector<int64_t>& sizes,
+ const std::vector<T>& data,
+ float epsilon = 0.1f) {
+ CAFFE_ENFORCE_EQ(tensor.sizes(), sizes);
+ assertTensorEquals(tensor, data, epsilon);
+}
+
// Asserts a list of tensors presented in two workspaces are equal.
void assertTensorListEquals(
const std::vector<std::string>& tensorNames,
template <typename T>
WorkspaceMutator& newTensorConst(
const string& name,
- const vector<int64_t>& shape,
+ const std::vector<int64_t>& shape,
const T& data) {
createTensorAndConstantFill<T>(name, shape, data, workspace_);
return *this;