#include <caffe2/ideep/ideep_utils.h>
+#include <caffe2/ideep/operators/operator_fallback_ideep.h>
+#include <caffe2/operators/concat_split_op.h>
namespace caffe2 {
public:
USE_IDEEP_DEF_ALIASES();
USE_IDEEP_OPERATOR_FUNCTIONS();
+ using FALLBACK_OP = IDEEPFallbackOp<ConcatOp<CPUContext>, SkipIndices<0>>;
IDEEPConcatOp(const OperatorDef& operator_def, Workspace* ws)
- : IDEEPOperator(operator_def, ws) {
+ : IDEEPOperator(operator_def, ws),
+ fallback_(operator_def, ws) {
CAFFE_ENFORCE(
!(OperatorBase::HasArgument("axis") && OperatorBase::HasArgument("order")),
"You shouldn't specify both the dim to concat, and the order "
// Trivial destructor: all members (fallback op, tensors) clean up via RAII.
virtual ~IDEEPConcatOp() {}
bool RunOnDevice() override {
- auto* output = Output(OUTPUT);
+ bool fallback_to_cpu = false;
+ vector<itensor> inputs_itensor;
- vector<itensor> inputs;
for (int i = 0; i < InputSize(); ++i) {
if (OperatorBase::InputBlob(i).template IsType<itensor>()) {
- inputs.emplace_back(Input(i));
+ auto& tensor_ideep = Input(i);
+ if (tensor_ideep.ndims() == 0 || tensor_ideep.get_nelems() == 0)
+ continue;
+ inputs_itensor.emplace_back(tensor_ideep);
} else {
CAFFE_ENFORCE(
BlobIsTensorType(OperatorBase::InputBlob(i), CPU),
"Expect cpu tensor if not itensor");
auto& tensor_cpu = OperatorBase::Input<Tensor>(i, CPU);
- CAFFE_ENFORCE(
- tensor_cpu.sizes().size() == 0 || tensor_cpu.numel() == 0,
- "Expect zero dim tensor");
+ if (tensor_cpu.sizes().size() == 0 || tensor_cpu.numel() == 0)
+ continue;
+ fallback_to_cpu = true;
+ break;
}
}
- auto axis_vdata = ideep::concat::compute(inputs, axis_, add_axis_, *output);
- Tensor* axis_info = OutputTensor(
- AXIS_INFO,
- vector<int64_t>(1, InputSize()),
- at::dtype<int>().device(CPU));
- int* axis_data = axis_info->template mutable_data<int>();
- for (int i = 0; i < axis_vdata.size(); i++) {
- axis_data[i] = axis_vdata[i];
+ if (!fallback_to_cpu) {
+ auto* output = Output(OUTPUT);
+ Tensor* axis_info = OutputTensor(AXIS_INFO,
+ vector<int64_t>(1, InputSize()), at::dtype<int>().device(CPU));
+ auto* axis_data = axis_info->template mutable_data<int>();
+ auto axis_vdata =
+ ideep::concat::compute(inputs_itensor, axis_, add_axis_, *output);
+ for (int i = 0; i < axis_vdata.size(); i++) {
+ axis_data[i] = axis_vdata[i];
+ }
+ return true;
}
- return true;
+ return fallback_.Run(0);
}
private:
int axis_;
int add_axis_;
+ FALLBACK_OP fallback_;
INPUT_TAGS(INPUT0);
OUTPUT_TAGS(OUTPUT, AXIS_INFO);
self.assertGradientChecks(gc, op, splits, i, [0])
+ @given(tensor_splits=_tensor_splits(add_axis=True), **mu.gcs)
+ def test_concat_with_TensorCPU(self, tensor_splits, gc, dc):
+ axis, _, splits = tensor_splits
+ op0 = core.CreateOperator(
+ "Concat",
+ ['X_{}'.format(i) for i in range(len(splits))],
+ ['concat_result0', 'split_info0'],
+ axis=axis,
+ add_axis=1,
+ device_option=dc[0]
+ )
+ op1 = core.CreateOperator(
+ "Concat",
+ ['X_{}'.format(i) for i in range(len(splits))],
+ ['concat_result1', 'split_info1'],
+ axis=axis,
+ add_axis=1,
+ device_option=dc[1]
+ )
+
+ for i, X in enumerate(splits):
+ workspace.FeedBlob('X_{}'.format(i), X, dc[0])
+
+ workspace.RunOperatorOnce(op0)
+ res0 = workspace.FetchBlob('concat_result0')
+ inf0 = workspace.FetchBlob('split_info0')
+
+ workspace.RunOperatorOnce(op1)
+ res1 = workspace.FetchBlob('concat_result1')
+ inf1 = workspace.FetchBlob('split_info1')
+
+ if not np.allclose(res0, res1, atol=0.0, rtol=0.0):
+ print(res1.flatten())
+ print(res0.flatten())
+ print(np.max(np.abs(res1 - res0)))
+ self.assertTrue(False)
+
+ if not np.allclose(inf0, inf1, atol=0.0, rtol=0.0):
+ print(inf1.flatten())
+ print(inf0.flatten())
+ print(np.max(np.abs(inf1 - inf0)))
+ self.assertTrue(False)
+
# Allow running this test module directly (e.g. `python concat_test.py`).
if __name__ == "__main__":
    unittest.main()