From aa50f1e36517ceccde968d76f2f499165c97fdee Mon Sep 17 00:00:00 2001 From: Lu Fang Date: Thu, 18 Apr 2019 23:56:32 -0700 Subject: [PATCH] Clean the onnx constant fold code a bit (#19398) Summary: This is a follow up PR of https://github.com/pytorch/pytorch/pull/18698 to lint the code using clang-format. Pull Request resolved: https://github.com/pytorch/pytorch/pull/19398 Differential Revision: D14994517 Pulled By: houseroad fbshipit-source-id: 2ae9f93e66ce66892a1edc9543ea03932cd82bee --- torch/csrc/jit/passes/onnx/constant_fold.cpp | 178 +++++++++++++++------------ torch/csrc/jit/passes/onnx/constant_fold.h | 2 + torch/csrc/jit/passes/onnx/peephole.h | 2 + 3 files changed, 100 insertions(+), 82 deletions(-) diff --git a/torch/csrc/jit/passes/onnx/constant_fold.cpp b/torch/csrc/jit/passes/onnx/constant_fold.cpp index 30b82d7..497c6c8 100644 --- a/torch/csrc/jit/passes/onnx/constant_fold.cpp +++ b/torch/csrc/jit/passes/onnx/constant_fold.cpp @@ -2,7 +2,7 @@ #include #include -#include +#include namespace torch { namespace jit { @@ -14,58 +14,63 @@ using namespace ::c10::onnx; namespace { using ParamMap = std::map; -using ValueToParamPairMap = std::map>; - -std::unordered_map onnxTypeToScalarTypeMap = -{ - // Only conversion of ONNX numeric types is included here. - // Unsigned ONNX types are mapped to the next higher signed - // ScalarType type. - {1, at::kFloat}, - {2, at::kByte}, - {3, at::kChar}, - {4, at::kInt}, - {5, at::kShort}, - {6, at::kInt}, - {7, at::kLong}, - {10, at::kFloat}, - {11, at::kDouble}, - {12, at::kLong}, +using ValueToParamPairMap = + std::map>; + +std::unordered_map onnxTypeToScalarTypeMap = { + // Only conversion of ONNX numeric types is included here. + // Unsigned ONNX types are mapped to the next higher signed + // ScalarType type. + {1, at::kFloat}, + {2, at::kByte}, + {3, at::kChar}, + {4, at::kInt}, + {5, at::kShort}, + {6, at::kInt}, + {7, at::kLong}, + {10, at::kFloat}, + {11, at::kDouble}, + {12, at::kLong}, }; -ValueToParamPairMap buildValueToParamsMap(Block* b, const ParamMap& paramsDict) { +ValueToParamPairMap buildValueToParamsMap( + Block* b, + const ParamMap& paramsDict) { ValueToParamPairMap valsToParamsMap; - for(auto& input : b->inputs()) { - auto it = paramsDict.find(input->uniqueName()); - if (it != paramsDict.end()) { - valsToParamsMap[input] = *it; - } + for (auto& input : b->inputs()) { + auto it = paramsDict.find(input->uniqueName()); + if (it != paramsDict.end()) { + valsToParamsMap.emplace(input, *it); + } } return valsToParamsMap; } void buildParamsMapFromValueToParamsMap( - const ValueToParamPairMap& valsToParamsMap, ParamMap& paramsDict) { + const ValueToParamPairMap& valsToParamsMap, + ParamMap& paramsDict) { paramsDict.clear(); - for(auto& nameTensorParamPair : valsToParamsMap) { + for (const auto& nameTensorParamPair : valsToParamsMap) { paramsDict.insert(nameTensorParamPair.second); } } void eraseUnusedBlockInputs(Block* b) { for (size_t i_1 = b->inputs().size(); i_1 > 0; --i_1) { - size_t i = i_1 - 1; - if (!b->inputs().at(i)->hasUses()) { - b->eraseInput(i); - } + size_t i = i_1 - 1; + if (!b->inputs().at(i)->hasUses()) { + b->eraseInput(i); + } } } -c10::optional runTorchBackendForOnnx(const Node* node, std::vector& inputTensorValues) { +c10::optional runTorchBackendForOnnx( + const Node* node, + std::vector& inputTensorValues) { at::Tensor updated_val; if (node->kind() == onnx::Slice) { assert(inputTensorValues.size() == 1); - if ( !(node->hasAttributeS("starts") && node->hasAttributeS("ends")) ) { + if (!(node->hasAttributeS("starts") && node->hasAttributeS("ends"))) { return c10::nullopt; } auto startsAttr = node->is(attr::starts); @@ -76,82 +81,85 @@ c10::optional runTorchBackendForOnnx(const Node* node, std::vector axesAttr; if (node->hasAttributeS("axes")) { axesAttr = node->is(attr::axes); - } - else { + } else { axesAttr.resize(startsAttr.size()); std::iota(axesAttr.begin(), axesAttr.end(), 0); } updated_val = inputTensorValues[0]; for (size_t i = 0; i < axesAttr.size(); ++i) { - updated_val = at::narrow(updated_val, axesAttr[i], startsAttr[i], endsAttr[i] - startsAttr[i]); + updated_val = at::narrow( + updated_val, axesAttr[i], startsAttr[i], endsAttr[i] - startsAttr[i]); } return c10::optional(updated_val); - } - else if (node->kind() == onnx::Concat) { + } else if (node->kind() == onnx::Concat) { if (!node->hasAttributeS("axis")) { return c10::nullopt; } - updated_val = at::cat(at::TensorList(inputTensorValues), node->i(attr::axis)); + updated_val = + at::cat(at::TensorList(inputTensorValues), node->i(attr::axis)); return c10::optional(updated_val); - } - else if (node->kind() == onnx::Unsqueeze) { + } else if (node->kind() == onnx::Unsqueeze) { assert(inputTensorValues.size() == 1); if (!node->hasAttributeS("axes")) { return c10::nullopt; } updated_val = inputTensorValues[0]; - for (auto axis: node->is(attr::axes)) { + for (auto axis : node->is(attr::axes)) { updated_val = at::unsqueeze(updated_val, axis); } return c10::optional(updated_val); - } - else if (node->kind() == onnx::Transpose) { + } else if (node->kind() == onnx::Transpose) { assert(inputTensorValues.size() == 1); if (!node->hasAttributeS("perm")) { return c10::nullopt; } updated_val = inputTensorValues[0].permute(node->is(attr::perm)); return c10::optional(updated_val); - } - else if (node->kind() == onnx::Cast) { + } else if (node->kind() == onnx::Cast) { assert(inputTensorValues.size() == 1); if (node->hasAttributeS("to") && - onnxTypeToScalarTypeMap.find(node->i(attr::to)) != onnxTypeToScalarTypeMap.end()) { - updated_val = inputTensorValues[0].to(onnxTypeToScalarTypeMap[node->i(attr::to)]); + onnxTypeToScalarTypeMap.find(node->i(attr::to)) != + onnxTypeToScalarTypeMap.end()) { + updated_val = + inputTensorValues[0].to(onnxTypeToScalarTypeMap[node->i(attr::to)]); return c10::optional(updated_val); } return c10::nullopt; - } - else { + } else { return c10::nullopt; } } bool isConstant(Value* val, const ValueToParamPairMap& valsToParamsMap) { auto parentNode = val->node(); - return (parentNode->kind() == prim::Param && - valsToParamsMap.find(val) != valsToParamsMap.end()) || // Checks val is a parameter and not a real input - (parentNode->kind() == onnx::Constant && !parentNode->mustBeNone() && - parentNode->kindOf(attr::value) == AttributeKind::t); // Check other types? + return (parentNode->kind() == prim::Param && + valsToParamsMap.find(val) != + valsToParamsMap + .end()) || // Checks val is a parameter and not a real input + (parentNode->kind() == onnx::Constant && !parentNode->mustBeNone() && + parentNode->kindOf(attr::value) == + AttributeKind::t); // Check other types? } -std::vector getValues(Node* node, const ValueToParamPairMap& valsToParamsMap) { +std::vector getValues( + Node* node, + const ValueToParamPairMap& valsToParamsMap) { size_t numInputs = node->inputs().size(); std::vector inputTensorValues; inputTensorValues.reserve(numInputs); for (auto val : node->inputs()) { if (val->node()->kind() == prim::Param) { auto itr = valsToParamsMap.find(val); - if(itr == valsToParamsMap.end()) { - throw std::runtime_error("getValues: Input value not found amongst constant parameters."); + if (itr == valsToParamsMap.end()) { + throw std::runtime_error( + "getValues: Input value not found amongst constant parameters."); } inputTensorValues.push_back(itr->second.second); - } - else if (val->node()->kind() == onnx::Constant) { + } else if (val->node()->kind() == onnx::Constant) { inputTensorValues.push_back(val->node()->t(attr::value)); - } - else { - throw std::runtime_error("getValues: Unsupported kind of constant node found."); + } else { + throw std::runtime_error( + "getValues: Unsupported kind of constant node found."); } } AT_ASSERT(inputTensorValues.size() == numInputs); @@ -163,16 +171,19 @@ void eraseUnusedValuesFromMap(ValueToParamPairMap& valsToParamsMap) { while (it != valsToParamsMap.end()) { if (!it->first->hasUses()) { it = valsToParamsMap.erase(it); - } - else { + } else { ++it; } } } -bool areNodeInputsConstant(Node* node, const ValueToParamPairMap& valsToParamsMap) { - return std::all_of(node->inputs().begin(), node->inputs().end(), - [&valsToParamsMap](Value* v) { return isConstant(v, valsToParamsMap); }); +bool areNodeInputsConstant( + Node* node, + const ValueToParamPairMap& valsToParamsMap) { + return std::all_of( + node->inputs().begin(), + node->inputs().end(), + [&valsToParamsMap](Value* v) { return isConstant(v, valsToParamsMap); }); } std::vector getOnnxConstParents(Node* node) { @@ -187,7 +198,7 @@ std::vector getOnnxConstParents(Node* node) { } // Anonymous namespace -// This method updates the block in-place to fold all the one-time +// This method updates the block in-place to fold all the one-time // constant-based computations/ops into an initializer node. void ConstantFoldONNX(Block* b, ParamMap& paramsDict) { AT_ASSERT(b->param_node()); @@ -197,31 +208,33 @@ void ConstantFoldONNX(Block* b, ParamMap& paramsDict) { for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) { auto node = *it; if (node->outputs().size() > 1) { - // Constant folding for multiple-output nodes not supported. Skip it. - continue; - } + // Constant folding for multiple-output nodes not supported. Skip it. + continue; + } if (!areNodeInputsConstant(node, valsToParamsMap)) { - // If all the inputs to this node are not either parameter or - // onnx::Constant, then skip this node. - continue; + // If all the inputs to this node are not either parameter or + // onnx::Constant, then skip this node. + continue; } auto inputTensorValues = getValues(node, valsToParamsMap); if (inputTensorValues.empty()) { - // This is a terminal node with no inputs, such as onnx::Constant. Skip it. - continue; + // This is a terminal node with no inputs, such as onnx::Constant. Skip + // it. + continue; } auto updatedValWrapped = runTorchBackendForOnnx(node, inputTensorValues); if (updatedValWrapped == c10::nullopt) { - // Constant folding is not supported for this op. Skip it. - continue; - } + // Constant folding is not supported for this op. Skip it. + continue; + } // Create a new input to the block (prim::Param node output). Add a // corresponding entryin valToParamMap. Replace the downstream inputs // with this value, and disconnect all the input values of the folded node. at::Tensor updatedVal = *updatedValWrapped; auto newSourceNodeOutput = b->addInput(); - valsToParamsMap.insert({newSourceNodeOutput, - std::make_pair(newSourceNodeOutput->uniqueName(), updatedVal)}); + valsToParamsMap.insert( + {newSourceNodeOutput, + std::make_pair(newSourceNodeOutput->uniqueName(), updatedVal)}); newSourceNodeOutput->inferTypeFrom(updatedVal); node->outputs().at(0)->replaceAllUsesWith(newSourceNodeOutput); @@ -234,8 +247,9 @@ void ConstantFoldONNX(Block* b, ParamMap& paramsDict) { // by eraseUnusedBlockInputs() call (below) outside the loop. auto onnxConstParents = getOnnxConstParents(node); node->removeAllInputs(); - std::for_each(onnxConstParents.begin(), onnxConstParents.end(), - [](Node* n){ n->destroy(); }); + for (auto* n : onnxConstParents) { + n->destroy(); + } it.destroyCurrent(); } eraseUnusedValuesFromMap(valsToParamsMap); diff --git a/torch/csrc/jit/passes/onnx/constant_fold.h b/torch/csrc/jit/passes/onnx/constant_fold.h index b6461b0..2eb2146 100644 --- a/torch/csrc/jit/passes/onnx/constant_fold.h +++ b/torch/csrc/jit/passes/onnx/constant_fold.h @@ -1,3 +1,5 @@ +#pragma once + #include namespace torch { diff --git a/torch/csrc/jit/passes/onnx/peephole.h b/torch/csrc/jit/passes/onnx/peephole.h index 63e9132..f06cb6b 100644 --- a/torch/csrc/jit/passes/onnx/peephole.h +++ b/torch/csrc/jit/passes/onnx/peephole.h @@ -1,3 +1,5 @@ +#pragma once + #include namespace torch { -- 2.7.4