From daedec235009dd05cdd5cb0281deb4ad6ca719c7 Mon Sep 17 00:00:00 2001
From: Lu Fang
Date: Fri, 18 Jan 2019 22:55:41 -0800
Subject: [PATCH] Support ConstantOfShape in Caffe2 ONNX Backend (#16108)

Summary:
This PR is the prerequisite to land https://github.com/pytorch/pytorch/pull/16095
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16108

Reviewed By: BIT-silence

Differential Revision: D13725722

Pulled By: houseroad

fbshipit-source-id: 28c0fb72f075cd04f9db44dfab0163844c20c620
---
Review note: in the ConstantFill branch of BuildTensorFillingOp the scalar
fill value is copied from onnx_tensor.raw_data().data() (the serialized
bytes) instead of from the address of the object returned by raw_data(),
and value_size is accumulated as int64_t so the product of the tensor dims
is not narrowed to int.

 caffe2/onnx/backend.cc                        | 170 +++++++++++++++++++-------
 caffe2/onnx/backend.h                         |   7 +-
 caffe2/operators/filler_op.h                  |  34 ++++--
 caffe2/python/onnx/tests/onnx_backend_test.py |   3 +-
 4 files changed, 158 insertions(+), 56 deletions(-)

diff --git a/caffe2/onnx/backend.cc b/caffe2/onnx/backend.cc
index 7ea113f..abc4b68 100644
--- a/caffe2/onnx/backend.cc
+++ b/caffe2/onnx/backend.cc
@@ -303,6 +303,7 @@ Caffe2Backend::get_renamed_operators() const {
       {"Unsqueeze", "ExpandDims"},
       {"Tile", "NumpyTile"},
       {"DynamicSlice", "Slice"},
+      {"ConstantOfShape", "ConstantFill"},
       {"RandomNormal", "GaussianFill"}};
   return kRenamedOperators;
 }
@@ -340,6 +341,7 @@ Caffe2Backend::get_special_operators() const {
               {"ArgMin", &Caffe2Backend::CreateArgMaxMin},
               {"Cast", &Caffe2Backend::CreateCast},
               {"Constant", &Caffe2Backend::CreateConstant},
+              {"ConstantOfShape", &Caffe2Backend::CreateConstantOfShape},
               {"Conv", &Caffe2Backend::CreateConvPoolOpBase},
               {"AveragePool", &Caffe2Backend::CreatePadPool},
               {"GlobalAveragePool", &Caffe2Backend::CreatePadPool},
@@ -459,6 +461,20 @@ Caffe2Ops Caffe2Backend::CreateConstant(
   return ret;
 }
 
+Caffe2Ops Caffe2Backend::CreateConstantOfShape(
+    OnnxNode* onnx_node,
+    const ConversionContext& ctx) {
+  CAFFE_ENFORCE_EQ(onnx_node->node.input_size(), 1);
+  CAFFE_ENFORCE_EQ(onnx_node->node.output_size(), 1);
+
+  Caffe2Ops ret;
+  auto* c2_op = ret.ops.Add();
+  const auto* value = onnx_node->attributes.get<const TensorProto*>("value");
+  BuildTensorFillingOp(c2_op, *value, onnx_node->node.output(0), onnx_node->node.input(0));
+
+  return ret;
+}
+
 // Note [Caffe2 ConvPoolOpBase]
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 // To understand what is going on here, we have to talk a little bit about
@@ -1673,8 +1689,9 @@ void ConvertIntegralValueToCaffe2<::google::protobuf::uint64>(caffe2::OperatorDe
 void Caffe2Backend::BuildTensorFillingOp(
     caffe2::OperatorDef* c2_op,
     const TensorProto& onnx_tensor,
-    const std::string& name) {
-  auto fill_name = name.empty() ? onnx_tensor.name() : name;
+    const std::string& output_name,
+    const std::string& shape_name) {
+  auto fill_name = output_name.empty() ? onnx_tensor.name() : output_name;
   CAFFE_ENFORCE(!fill_name.empty());
 
   if (onnx_tensor.has_segment()) {
@@ -1682,53 +1699,118 @@ void Caffe2Backend::BuildTensorFillingOp(
   }
 
   auto* c2_values = c2_op->add_arg();
-  c2_values->set_name("values");
-
-  if (onnx_tensor.data_type() == TensorProto::FLOAT) {
-    c2_op->set_type("GivenTensorFill");
-    auto* floats = c2_values->mutable_floats();
-    if (!TryConvertingTensorRawValues<float>(onnx_tensor, floats)) {
-      floats->CopyFrom(onnx_tensor.float_data());
-    }
-  } else if (onnx_tensor.data_type() == TensorProto::DOUBLE) {
-    c2_op->set_type("GivenTensorDoubleFill");
-    ::google::protobuf::RepeatedField<double> tmp;
-    const ::google::protobuf::RepeatedField<double>* src = &tmp;
-    if (!TryConvertingTensorRawValues<double>(onnx_tensor, &tmp)) {
-      src = &onnx_tensor.double_data();
+  // if shape_name is empty, we generate GivenTensorFill
+  // otherwise, we generate ConstantFill, which accept shape as input
+  if (shape_name.empty()) {
+    // GivenTensor*Fill uses values
+    c2_values->set_name("values");
+    if (onnx_tensor.data_type() == TensorProto::FLOAT) {
+      c2_op->set_type("GivenTensorFill");
+      auto* floats = c2_values->mutable_floats();
+      if (!TryConvertingTensorRawValues<float>(onnx_tensor, floats)) {
+        floats->CopyFrom(onnx_tensor.float_data());
+      }
+    } else if (onnx_tensor.data_type() == TensorProto::DOUBLE) {
+      c2_op->set_type("GivenTensorDoubleFill");
+      ::google::protobuf::RepeatedField<double> tmp;
+      const ::google::protobuf::RepeatedField<double>* src = &tmp;
+      if (!TryConvertingTensorRawValues<double>(onnx_tensor, &tmp)) {
+        src = &onnx_tensor.double_data();
+      }
+      for (const auto i : *src) {
+        c2_values->add_floats(i);
+      }
+    } else if (onnx_tensor.data_type() == TensorProto::INT64) {
+      ConvertIntegralValueToCaffe2<::google::protobuf::int64>(c2_op, c2_values, onnx_tensor);
+    } else if (onnx_tensor.data_type() == TensorProto::UINT32) {
+      ConvertIntegralValueToCaffe2<::google::protobuf::uint64>(c2_op, c2_values, onnx_tensor);
+    } else if (onnx_tensor.data_type() == TensorProto::BOOL) {
+      ConvertIntegralValueToCaffe2<::google::protobuf::int8>(c2_op, c2_values, onnx_tensor);
+    } else if (onnx_tensor.data_type() == TensorProto::UINT8) {
+      ConvertIntegralValueToCaffe2<::google::protobuf::uint8>(c2_op, c2_values, onnx_tensor);
+    } else if (onnx_tensor.data_type() == TensorProto::INT8) {
+      ConvertIntegralValueToCaffe2<::google::protobuf::int8>(c2_op, c2_values, onnx_tensor);
+    } else if (onnx_tensor.data_type() == TensorProto::UINT16) {
+      ConvertIntegralValueToCaffe2<::google::protobuf::uint16>(c2_op, c2_values, onnx_tensor);
+    } else if (onnx_tensor.data_type() == TensorProto::INT16) {
+      ConvertIntegralValueToCaffe2<::google::protobuf::int16>(c2_op, c2_values, onnx_tensor);
+    } else if (onnx_tensor.data_type() == TensorProto::INT32) {
+      ConvertIntegralValueToCaffe2<::google::protobuf::int32>(c2_op, c2_values, onnx_tensor);
+    } else if (onnx_tensor.data_type() == TensorProto::STRING) {
+      c2_op->set_type("GivenTensorStringFill");
+      auto* strings = c2_values->mutable_strings();
+      strings->CopyFrom(onnx_tensor.string_data());
+    } else {
+      CAFFE_THROW("unrecognized tensor type: ", onnx_tensor.data_type());
     }
-    for (const auto i : *src) {
-      c2_values->add_floats(i);
+    auto* c2_shape = c2_op->add_arg();
+    c2_shape->set_name("shape");
+    for (const auto d : onnx_tensor.dims()) {
+      c2_shape->add_ints(d);
     }
-  } else if (onnx_tensor.data_type() == TensorProto::INT64) {
-    ConvertIntegralValueToCaffe2<::google::protobuf::int64>(c2_op, c2_values, onnx_tensor);
-  } else if (onnx_tensor.data_type() == TensorProto::UINT32) {
-    ConvertIntegralValueToCaffe2<::google::protobuf::uint64>(c2_op, c2_values, onnx_tensor);
-  } else if (onnx_tensor.data_type() == TensorProto::BOOL) {
-    ConvertIntegralValueToCaffe2<::google::protobuf::int8>(c2_op, c2_values, onnx_tensor);
-  } else if (onnx_tensor.data_type() == TensorProto::UINT8) {
-    ConvertIntegralValueToCaffe2<::google::protobuf::uint8>(c2_op, c2_values, onnx_tensor);
-  } else if (onnx_tensor.data_type() == TensorProto::INT8) {
-    ConvertIntegralValueToCaffe2<::google::protobuf::int8>(c2_op, c2_values, onnx_tensor);
-  } else if (onnx_tensor.data_type() == TensorProto::UINT16) {
-    ConvertIntegralValueToCaffe2<::google::protobuf::uint16>(c2_op, c2_values, onnx_tensor);
-  } else if (onnx_tensor.data_type() == TensorProto::INT16) {
-    ConvertIntegralValueToCaffe2<::google::protobuf::int16>(c2_op, c2_values, onnx_tensor);
-  } else if (onnx_tensor.data_type() == TensorProto::INT32) {
-    ConvertIntegralValueToCaffe2<::google::protobuf::int32>(c2_op, c2_values, onnx_tensor);
-  } else if (onnx_tensor.data_type() == TensorProto::STRING) {
-    c2_op->set_type("GivenTensorStringFill");
-    auto* strings = c2_values->mutable_strings();
-    strings->CopyFrom(onnx_tensor.string_data());
   } else {
-    CAFFE_THROW("unrecognized tensor type: ", onnx_tensor.data_type());
+    int64_t value_size = 1;
+    for (const auto d : onnx_tensor.dims()) {
+      value_size *= d;
+    }
+    CAFFE_ENFORCE(value_size == 1);
+    auto c2_input_as_shape = c2_op->add_arg();
+    c2_input_as_shape->set_name("input_as_shape");
+    c2_input_as_shape->set_i(1);
+    c2_values->set_name("value");
+    auto* c2_dtype = c2_op->add_arg();
+    c2_dtype->set_name("dtype");
+    if (onnx_tensor.data_type() == TensorProto::FLOAT) {
+      c2_dtype->set_i(caffe2::TensorProto::FLOAT);
+      if (onnx_tensor.float_data_size() > 0) {
+        c2_values->set_f(onnx_tensor.float_data(0));
+      } else {
+        CAFFE_ENFORCE(onnx_tensor.raw_data().size() == sizeof(float));
+        float f;
+        memcpy(&f, onnx_tensor.raw_data().data(), sizeof(float));
+        c2_values->set_f(f);
+      }
+    } else if (onnx_tensor.data_type() == TensorProto::DOUBLE){
+      c2_dtype->set_i(caffe2::TensorProto::DOUBLE);
+      if (onnx_tensor.double_data_size() > 0) {
+        c2_values->set_f(static_cast<float>(onnx_tensor.double_data(0)));
+      } else {
+        CAFFE_ENFORCE(onnx_tensor.raw_data().size() == sizeof(double));
+        double d;
+        memcpy(&d, onnx_tensor.raw_data().data(), sizeof(double));
+        c2_values->set_f(static_cast<float>(d));
+      }
+    } else if (onnx_tensor.data_type() == TensorProto::INT64){
+      c2_dtype->set_i(caffe2::TensorProto::INT64);
+      if (onnx_tensor.int64_data_size() > 0) {
+        c2_values->set_i(onnx_tensor.int64_data(0));
+      } else {
+        CAFFE_ENFORCE(onnx_tensor.raw_data().size() == sizeof(int64_t));
+        int64_t i;
+        memcpy(&i, onnx_tensor.raw_data().data(), sizeof(int64_t));
+        c2_values->set_i(i);
+      }
+    } else if (onnx_tensor.data_type() == TensorProto::INT32){
+      c2_dtype->set_i(caffe2::TensorProto::INT32);
+      if (onnx_tensor.int32_data_size() > 0) {
+        c2_values->set_i(onnx_tensor.int32_data(0));
+      } else {
+        CAFFE_ENFORCE(onnx_tensor.raw_data().size() == sizeof(int32_t));
+        int32_t i;
+        memcpy(&i, onnx_tensor.raw_data().data(), sizeof(int32_t));
+        c2_values->set_i(i);
+      }
+    } else {
+      // TODO: to support more data type
+      std::stringstream oss;
+      oss << "Unsupported dtype: " << onnx_tensor.data_type();
+      CAFFE_THROW(oss.str());
+    }
+    // ConstantFill uses value
+    c2_op->set_type("ConstantFill");
+    c2_op->add_input(shape_name);
   }
 
-  auto* c2_shape = c2_op->add_arg();
-  c2_shape->set_name("shape");
-  for (const auto d : onnx_tensor.dims()) {
-    c2_shape->add_ints(d);
-  }
   c2_op->add_output(fill_name);
 }
 
diff --git a/caffe2/onnx/backend.h b/caffe2/onnx/backend.h
index fa8a70c..d61af29 100644
--- a/caffe2/onnx/backend.h
+++ b/caffe2/onnx/backend.h
@@ -157,7 +157,8 @@ class CAFFE2_API Caffe2Backend {
   void BuildTensorFillingOp(
       caffe2::OperatorDef* c2_op,
       const TensorProto& onnx_tensor,
-      const std::string& name = "");
+      const std::string& output_name = "",
+      const std::string& shape_name = "");
 
  private:
   using SpecialOpConverter =
@@ -192,6 +193,10 @@ class CAFFE2_API Caffe2Backend {
 
   Caffe2Ops CreateConstant(OnnxNode* onnx_node, const ConversionContext& ctx);
 
+  Caffe2Ops CreateConstantOfShape(
+      OnnxNode* onnx_node,
+      const ConversionContext& ctx);
+
   Caffe2Ops CreateConvPoolOpBase(
       OnnxNode* onnx_node,
       const ConversionContext& ctx);
diff --git a/caffe2/operators/filler_op.h b/caffe2/operators/filler_op.h
index 3945301..4b4df15 100644
--- a/caffe2/operators/filler_op.h
+++ b/caffe2/operators/filler_op.h
@@ -55,15 +55,31 @@ class FillerOp : public Operator<Context> {
     if (InputSize()) {
       auto shape = vector<int64_t>{};
       if (input_as_shape_) {
-        // Shape input must be in CPU context
-        auto& input = this->template Input<Tensor>(0, CPU);
-        CAFFE_ENFORCE_EQ(
-            input.dim(),
-            1,
-            "When input_as_shape is true, the input must be a 1D tensor of "
-            "data type int64_t");
-        auto* shape_data = input.template data<int64_t>();
-        shape.insert(shape.end(), shape_data, shape_data + input.dim32(0));
+        if (this->InputIsTensorType(0, CPU)) {
+          // originally, shape input must be in CPU context
+          auto& input = this->template Input<Tensor>(0, CPU);
+          CAFFE_ENFORCE_EQ(
+              input.dim(),
+              1,
+              "When input_as_shape is true, the input must be a 1D tensor of "
+              "data type int64_t");
+          CAFFE_ENFORCE(input.numel() > 0);
+          auto* shape_data = input.template data<int64_t>();
+          shape.insert(shape.end(), shape_data, shape_data + input.dim32(0));
+        } else {
+          // in ONNX case, we allow shape to be in CUDA context
+          auto& input = Input(0);
+          CAFFE_ENFORCE_EQ(
+              input.dim(),
+              1,
+              "When input_as_shape is true, the input must be a 1D tensor of "
+              "data type int64_t");
+          CAFFE_ENFORCE(input.numel() > 0);
+          auto* shape_data = input.template data<int64_t>();
+          std::unique_ptr<int64_t[]> shape_data_copy = caffe2::make_unique<int64_t[]>(input.dim32(0));
+          context_.template CopyToCPU<int64_t>(input.dim32(0), shape_data, shape_data_copy.get());
+          shape.insert(shape.end(), shape_data_copy.get(), shape_data_copy.get() + input.dim32(0));
+        }
       } else {
         auto& input = Input(0);
         shape.insert(shape.end(), input.sizes().begin(), input.sizes().end());
diff --git a/caffe2/python/onnx/tests/onnx_backend_test.py b/caffe2/python/onnx/tests/onnx_backend_test.py
index b979759..c2bf729 100644
--- a/caffe2/python/onnx/tests/onnx_backend_test.py
+++ b/caffe2/python/onnx/tests/onnx_backend_test.py
@@ -41,7 +41,6 @@ backend_test.exclude(r'(test_hardsigmoid'  # Does not support Hardsigmoid.
                      '|test_convtranspose.*'  # ConvTranspose needs some more complicated translation
                      '|test_mvn.*'  # MeanVarianceNormalization is experimental and not supported.
                      '|test_dynamic_slice.*'  # MeanVarianceNormalization is experimental and not supported.
-                     '|test_constantlike.*'  # Needs implementation
                      '|test_eyelike.*'  # Needs implementation
                      '|test_maxunpool.*'  # Needs implementation
                      '|test_acosh.*'  # Needs implementation
@@ -51,7 +50,7 @@ backend_test.exclude(r'(test_hardsigmoid'  # Does not support Hardsigmoid.
                      '|test_scan.*'  # Needs implementation
                      '|test_isnan.*'  # Needs implementation
                      '|test_scatter.*'  # Should be similar to ScatterAssign
-                     '|test_constantofshape.*'  # Needs implementation
+                     '|test_constantofshape_int.*'  # Needs implementation
                      '|test_where.*'  # Needs implementation
                      '|test_shrink.*'  # Needs implementation
                      ')')
-- 
2.7.4