From fabd23cb2da00d6c0f033ba22a07106302213fb7 Mon Sep 17 00:00:00 2001 From: Xianjie Chen Date: Wed, 12 Dec 2018 21:31:14 -0800 Subject: [PATCH] support casting to string (#15110) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15110 support casting to string on CPU Reviewed By: intermilan Differential Revision: D13429381 fbshipit-source-id: b737a1ba1237b10f692d5c42b42a544b94ba9fd1 --- caffe2/operators/cast_op.cc | 25 ++++++++++++++++++++----- caffe2/python/operator_test/cast_op_test.py | 18 +++++++++++++++++- 2 files changed, 37 insertions(+), 6 deletions(-) diff --git a/caffe2/operators/cast_op.cc b/caffe2/operators/cast_op.cc index 70da41a..82b6355 100644 --- a/caffe2/operators/cast_op.cc +++ b/caffe2/operators/cast_op.cc @@ -2,6 +2,20 @@ namespace caffe2 { +template +struct CastHelper { + static DstType call(SrcType data) { + return static_cast(data); + } +}; + +template +struct CastHelper { + static std::string call(SrcType data) { + return caffe2::to_string(data); + } +}; + template <> template bool CastOp::DoRunWithType() { @@ -12,7 +26,7 @@ bool CastOp::DoRunWithType() { auto* out = output->template mutable_data(); auto N = input.numel(); for (int64_t i = 0; i < N; ++i) { - out[i] = static_cast(data[i]); + out[i] = CastHelper::call(data[i]); } return true; } @@ -31,8 +45,8 @@ void CastOp::SetBody(TensorProto_DataType to) { LOG(FATAL) << "BYTE is deprecated"; break; case TensorProto_DataType_STRING: - CAFFE_THROW("Casting to and from strings is not supported yet"); - // break; + body_ = &CastOp::DoRunWithDstType; + break; case TensorProto_DataType_BOOL: body_ = &CastOp::DoRunWithDstType; break; @@ -55,7 +69,7 @@ void CastOp::SetBody(TensorProto_DataType to) { CAFFE_THROW("Casting to and from at::Half on CPU is not supported yet"); // break; case TensorProto_DataType_DOUBLE: - //body_ = &CastOp::DoRunIncFp16WithDstType; + // body_ = &CastOp::DoRunIncFp16WithDstType; body_ = &CastOp::DoRunWithDstType; break; case TensorProto_DataType_UNDEFINED: @@ -104,7 +118,8 @@ enum field in the TensorProto message (see below). If the `to` argument is not provided or is not one of the enumerated types in *DataType*, Caffe2 throws an Enforce error. -NOTE: Casting to and from strings is not supported yet. +NOTE: Casting from strings is not supported, and casting to strings is only +supported on CPU. TensorProto *DataType* field: ``` diff --git a/caffe2/python/operator_test/cast_op_test.py b/caffe2/python/operator_test/cast_op_test.py index 39195c8..f7ffb5b 100644 --- a/caffe2/python/operator_test/cast_op_test.py +++ b/caffe2/python/operator_test/cast_op_test.py @@ -4,7 +4,7 @@ from __future__ import print_function from __future__ import unicode_literals -from caffe2.python import core +from caffe2.python import core, workspace import caffe2.python.hypothesis_test_util as hu from hypothesis import given @@ -30,3 +30,19 @@ class TestCastOp(hu.HypothesisTestCase): self.assertDeviceChecks(dc, op, [data], [0]) # This is actually 0 self.assertGradientChecks(gc, op, [data], 0, [0]) + + @given(data=hu.tensor(dtype=np.int32), **hu.gcs_cpu_only) + def test_cast_int_to_string(self, data, gc, dc): + op = core.CreateOperator( + 'Cast', 'data', 'data_cast', to=core.DataType.STRING) + + def ref(data): + ret = data.astype(dtype=np.str) + # the string blob will be fetched as object, we feed and re-fetch + # to mimic this. + with hu.temp_workspace('tmp_ref_int_to_string'): + workspace.FeedBlob('tmp_blob', ret) + fetched_ret = workspace.FetchBlob('tmp_blob') + return (fetched_ret, ) + + self.assertReferenceChecks(gc, op, inputs=[data], reference=ref) -- 2.7.4