From: Raghavan Raman
Date: Wed, 25 Aug 2021 18:12:57 +0000 (-0700)
Subject: [Static Runtime] Moved NNC operator definitions to separate files. (#63838)
X-Git-Tag: accepted/tizen/unified/20221215.051042~728
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=a2399a76e18b31747019ead8f80e2a4eb53e8223;p=platform%2Fupstream%2Fpytorch.git

[Static Runtime] Moved NNC operator definitions to separate files. (#63838)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63838

Refactored the NNC operator definitions into separate files.

Made `TEWrapper` a class with a fixed set of methods, and added separate
definitions for those methods based on `TORCH_ENABLE_LLVM`, to keep the same
functionality as before.

Test Plan: Built and ran Static Runtime tests.

Reviewed By: hlu1

Differential Revision: D30405467

fbshipit-source-id: 606ef852bb820d5e23a0f8af1bf5dc122e90bceb
---

diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl
index 157c30663c..0d888ea8a4 100644
--- a/tools/build_variables.bzl
+++ b/tools/build_variables.bzl
@@ -329,6 +329,7 @@ core_sources_full = core_sources_full_mobile + [
     "torch/csrc/jit/runtime/static/native_ops.cpp",
     "torch/csrc/jit/runtime/static/ops.cpp",
     "torch/csrc/jit/runtime/static/passes.cpp",
+    "torch/csrc/jit/runtime/static/te_wrapper.cpp",
     "torch/csrc/jit/tensorexpr/external_functions.cpp",
     "torch/csrc/jit/tensorexpr/external_functions_codegen.cpp",
 ]
diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp
index e6af641083..d9fb9bad08 100644
--- a/torch/csrc/jit/runtime/static/ops.cpp
+++ b/torch/csrc/jit/runtime/static/ops.cpp
@@ -15,6 +15,7 @@
 #include
 #include
 #include
+#include <torch/csrc/jit/runtime/static/te_wrapper.h>
 #include
 #include
 #include
@@ -493,198 +494,6 @@ REGISTER_OPERATOR_FUNCTOR(aten::leaky_relu, aten_leaky_relu, [](Node* n) -> SROp
   };
 });
 
-namespace {
-
-// Use the width of an AVX-512 vector by default; this happens to work OK
-// for AVX2 as well. Some ops benefit from using multiple AVX ports, in
-// which case they are vectorized by twice this constant. An exception is
-// logit, since it contains FP divide, which is single-ported.
-static constexpr int kVectorWidth = 16;
-
-#ifdef TORCH_ENABLE_LLVM
-
-struct TEWrapper {
-  std::unique_ptr<tensorexpr::LLVMCodeGen> cg;
-  TEWrapper() = default;
-  void update(std::unique_ptr<tensorexpr::LLVMCodeGen>&& cg_) {
-    cg = std::move(cg_);
-  }
-
-  void call(const std::vector<void*>& args) {
-    cg->call_raw(args);
-  }
-
-  inline bool supports(const at::Tensor& t) {
-    return t.is_contiguous() && t.dtype().Match<float>();
-  }
-};
-
-void optimizePointwise(
-    tensorexpr::LoopNest* ln,
-    tensorexpr::Tensor target,
-    int width) {
-  using namespace torch::jit::tensorexpr;
-  std::vector<ForPtr> loops = ln->getLoopStmtsFor(target);
-  ForPtr inner, tail;
-  TORCH_CHECK(loops.size() > 0, "No loops created for pointwise op");
-  ln->splitWithTail(loops[0], width, &inner, &tail);
-  ln->vectorize(inner);
-}
-
-std::shared_ptr<TEWrapper> wrapTECompute(
-    std::shared_ptr<TEWrapper> wrap,
-    tensorexpr::Placeholder& in,
-    tensorexpr::Tensor out,
-    tensorexpr::VarHandle& dim,
-    int width = kVectorWidth) {
-  using namespace torch::jit::tensorexpr;
-  LoopNest ln({out});
-  optimizePointwise(&ln, out, width);
-  ln.prepareForCodegen();
-  StmtPtr s = ln.root_stmt();
-  s = tensorexpr::IRSimplifier::simplify(s);
-  std::vector<CodeGen::BufferArg> args;
-  args.emplace_back(out);
-  args.emplace_back(in);
-  args.emplace_back(dim);
-  auto cg = std::make_unique<LLVMCodeGen>(s, args);
-  wrap->update(std::move(cg));
-  return wrap;
-};
-
-#else
-
-struct TEWrapper {
-  TEWrapper() = default;
-  template <typename... Ts>
-  void operator()(const Ts&... ts) {
-    DCHECK(0 && "Invalid call");
-  }
-  void call(const std::vector<void*>& args) {
-    DCHECK(0 && "Invalid call");
-  }
-
-  inline bool supports(const at::Tensor& t) {
-    return false;
-  }
-};
-
-std::shared_ptr<TEWrapper> wrapTECompute(
-    std::shared_ptr<TEWrapper> wrap,
-    tensorexpr::Placeholder& in,
-    tensorexpr::Tensor out,
-    tensorexpr::VarHandle& dim,
-    int width = kVectorWidth) {
-  return wrap;
-};
-
-#endif
-
-std::mutex& getNNCCacheMutex() {
-  static std::mutex nncCacheMutex;
-  return nncCacheMutex;
-}
-
-std::unordered_map<NodeKind, std::shared_ptr<TEWrapper>>& getNNCCache() {
-  static std::unordered_map<NodeKind, std::shared_ptr<TEWrapper>> nncCache;
-  return nncCache;
-}
-
-std::shared_ptr<TEWrapper> lookupNNCCache(NodeKind kind) {
-  std::lock_guard<std::mutex> lock(getNNCCacheMutex());
-  auto it = getNNCCache().find(kind);
-  if (it != getNNCCache().end()) {
-    return it->second;
-  }
-  return nullptr;
-}
-
-void updateNNCCache(NodeKind kind, std::shared_ptr<TEWrapper> code) {
-  std::lock_guard<std::mutex> lock(getNNCCacheMutex());
-  getNNCCache()[kind] = code;
-}
-
-} // namespace
-
-std::shared_ptr<TEWrapper> createLogit(c10::optional<float> clamp) {
-  using namespace torch::jit::tensorexpr;
-  // TODO: Use NNC cache for this op.
-  auto wrap = std::make_shared<TEWrapper>();
-  auto N = VarHandle("N", kInt);
-  Placeholder A("A", kFloat, {N});
-  tensorexpr::Tensor B = Compute("B", {N}, [&](const VarHandle& i) {
-    auto A_elem = [&]() {
-      if (!clamp) {
-        return A.load(i);
-      } else {
-        auto elem = A.load(i);
-        auto min = FloatImm::make(*clamp);
-        auto max = FloatImm::make(1.0f - *clamp);
-        elem = CompareSelect::make(elem, min, min, elem, kLT);
-        return CompareSelect::make(elem, max, max, elem, kGT);
-      }
-    }();
-    return log_vml(A_elem / (FloatImm::make(1.0f) - A_elem));
-  });
-  return wrapTECompute(wrap, A, B, N);
-}
-
-std::shared_ptr<TEWrapper> createRelu() {
-  using namespace torch::jit::tensorexpr;
-  auto wrap = lookupNNCCache(aten::relu);
-  if (wrap) {
-    return wrap;
-  }
-  wrap = std::make_shared<TEWrapper>();
-  auto N = VarHandle("N", kInt);
-  Placeholder A("A", kFloat, {N});
-  tensorexpr::Tensor B = Compute("B", {N}, [&](const VarHandle& i) {
-    auto zero = FloatImm::make(0.f);
-    auto a = A.load(i);
-    return ifThenElse(a < zero, zero, a);
-  });
-  wrap = wrapTECompute(wrap, A, B, N);
-  updateNNCCache(aten::relu, wrap);
-  return wrap;
-}
-
-std::shared_ptr<TEWrapper> createTanh() {
-  using namespace torch::jit::tensorexpr;
-  auto wrap = lookupNNCCache(aten::tanh);
-  if (wrap) {
-    return wrap;
-  }
-  wrap = std::make_shared<TEWrapper>();
-  auto N = VarHandle("N", kInt);
-  Placeholder A("A", kFloat, {N});
-  tensorexpr::Tensor B = Compute("B", {N}, [&](const VarHandle& i) {
-    auto a = A.load(i);
-    return fast_tanh(a);
-  });
-  wrap = wrapTECompute(wrap, A, B, N);
-  updateNNCCache(aten::tanh, wrap);
-  return wrap;
-}
-
-std::shared_ptr<TEWrapper> createSigmoid() {
-  using namespace torch::jit::tensorexpr;
-  auto wrap = lookupNNCCache(aten::sigmoid);
-  if (wrap) {
-    return wrap;
-  }
-  wrap = std::make_shared<TEWrapper>();
-  auto N = VarHandle("N", kInt);
-  Placeholder A("A", kFloat, {N});
-  Tensor B =
-      Compute("B", {N}, [&](const VarHandle& i) { return sigmoid(A.load(i)); });
-  // NNC uses sleef for vectorizing sigmoid, which comes in an 8-wide flavor
-  // (Sleef_expf8).
-  constexpr int kSleefWidth = 8;
-  wrap = wrapTECompute(wrap, A, B, N, kSleefWidth);
-  updateNNCCache(aten::sigmoid, wrap);
-  return wrap;
-}
-
 REGISTER_OPERATOR_FUNCTOR(aten::relu, aten_relu, [](Node* n) -> SROperator {
   if (!n->matches(torch::schema("aten::relu(Tensor self) -> Tensor"))) {
     LogAndDumpSchema(n);
diff --git a/torch/csrc/jit/runtime/static/te_wrapper.cpp b/torch/csrc/jit/runtime/static/te_wrapper.cpp
new file mode 100644
index 0000000000..413d44348a
--- /dev/null
+++ b/torch/csrc/jit/runtime/static/te_wrapper.cpp
@@ -0,0 +1,184 @@
+#include <torch/csrc/jit/runtime/static/te_wrapper.h>
+
+#include
+#include
+
+namespace torch {
+namespace jit {
+
+using namespace torch::jit::tensorexpr;
+
+// Use the width of an AVX-512 vector by default; this happens to work OK for
+// AVX2 as well. Some ops benefit from using multiple AVX ports, in which case
+// they are vectorized by twice this constant. An exception is logit, since it
+// contains FP divide, which is single-ported.
+static constexpr int kVectorWidth = 16;
+
+#ifdef TORCH_ENABLE_LLVM
+
+void TEWrapper::update(std::unique_ptr<LLVMCodeGen>&& cg_) {
+  cg = std::move(cg_);
+}
+
+void TEWrapper::call(const std::vector<void*>& args) {
+  cg->call_raw(args);
+}
+
+bool TEWrapper::supports(const at::Tensor& t) {
+  return t.is_contiguous() && t.dtype().Match<float>();
+}
+
+void optimizePointwise(LoopNest* ln, Tensor target, int width) {
+  std::vector<ForPtr> loops = ln->getLoopStmtsFor(target);
+  ForPtr inner, tail;
+  TORCH_CHECK(loops.size() > 0, "No loops created for pointwise op");
+  ln->splitWithTail(loops[0], width, &inner, &tail);
+  ln->vectorize(inner);
+}
+
+std::shared_ptr<TEWrapper> wrapTECompute(
+    std::shared_ptr<TEWrapper> wrap,
+    Placeholder& in,
+    Tensor out,
+    VarHandle& dim,
+    int width = kVectorWidth) {
+  LoopNest ln({out});
+  optimizePointwise(&ln, out, width);
+  ln.prepareForCodegen();
+  StmtPtr s = ln.root_stmt();
+  s = IRSimplifier::simplify(s);
+  std::vector<CodeGen::BufferArg> args;
+  args.emplace_back(out);
+  args.emplace_back(in);
+  args.emplace_back(dim);
+  auto cg = std::make_unique<LLVMCodeGen>(s, args);
+  wrap->update(std::move(cg));
+  return wrap;
+};
+
+#else
+
+void TEWrapper::call(const std::vector<void*>& args) {
+  DCHECK(0 && "Invalid call");
+}
+
+bool TEWrapper::supports(const at::Tensor& t) {
+  return false;
+}
+
+std::shared_ptr<TEWrapper> wrapTECompute(
+    std::shared_ptr<TEWrapper> wrap,
+    Placeholder& in,
+    Tensor out,
+    VarHandle& dim,
+    int width = kVectorWidth) {
+  return wrap;
+};
+
+#endif
+
+namespace {
+
+std::mutex& getNNCCacheMutex() {
+  static std::mutex nncCacheMutex;
+  return nncCacheMutex;
+}
+
+std::unordered_map<NodeKind, std::shared_ptr<TEWrapper>>& getNNCCache() {
+  static std::unordered_map<NodeKind, std::shared_ptr<TEWrapper>> nncCache;
+  return nncCache;
+}
+
+std::shared_ptr<TEWrapper> lookupNNCCache(NodeKind kind) {
+  std::lock_guard<std::mutex> lock(getNNCCacheMutex());
+  auto it = getNNCCache().find(kind);
+  if (it != getNNCCache().end()) {
+    return it->second;
+  }
+  return nullptr;
+}
+
+void updateNNCCache(NodeKind kind, std::shared_ptr<TEWrapper> code) {
+  std::lock_guard<std::mutex> lock(getNNCCacheMutex());
+  getNNCCache()[kind] = code;
+}
+
+} // namespace
+
+std::shared_ptr<TEWrapper> createLogit(c10::optional<float> clamp) {
+  // TODO: Use NNC cache for this op.
+  auto wrap = std::make_shared<TEWrapper>();
+  auto N = VarHandle("N", kInt);
+  Placeholder A("A", kFloat, {N});
+  Tensor B = Compute("B", {N}, [&](const VarHandle& i) {
+    auto A_elem = [&]() {
+      if (!clamp) {
+        return A.load(i);
+      } else {
+        auto elem = A.load(i);
+        auto min = FloatImm::make(*clamp);
+        auto max = FloatImm::make(1.0f - *clamp);
+        elem = CompareSelect::make(elem, min, min, elem, kLT);
+        return CompareSelect::make(elem, max, max, elem, kGT);
+      }
+    }();
+    return log_vml(A_elem / (FloatImm::make(1.0f) - A_elem));
+  });
+  return wrapTECompute(wrap, A, B, N);
+}
+
+std::shared_ptr<TEWrapper> createRelu() {
+  auto wrap = lookupNNCCache(aten::relu);
+  if (wrap) {
+    return wrap;
+  }
+  wrap = std::make_shared<TEWrapper>();
+  auto N = VarHandle("N", kInt);
+  Placeholder A("A", kFloat, {N});
+  Tensor B = Compute("B", {N}, [&](const VarHandle& i) {
+    auto zero = FloatImm::make(0.f);
+    auto a = A.load(i);
+    return ifThenElse(a < zero, zero, a);
+  });
+  wrap = wrapTECompute(wrap, A, B, N);
+  updateNNCCache(aten::relu, wrap);
+  return wrap;
+}
+
+std::shared_ptr<TEWrapper> createTanh() {
+  auto wrap = lookupNNCCache(aten::tanh);
+  if (wrap) {
+    return wrap;
+  }
+  wrap = std::make_shared<TEWrapper>();
+  auto N = VarHandle("N", kInt);
+  Placeholder A("A", kFloat, {N});
+  Tensor B = Compute("B", {N}, [&](const VarHandle& i) {
+    auto a = A.load(i);
+    return fast_tanh(a);
+  });
+  wrap = wrapTECompute(wrap, A, B, N);
+  updateNNCCache(aten::tanh, wrap);
+  return wrap;
+}
+
+std::shared_ptr<TEWrapper> createSigmoid() {
+  auto wrap = lookupNNCCache(aten::sigmoid);
+  if (wrap) {
+    return wrap;
+  }
+  wrap = std::make_shared<TEWrapper>();
+  auto N = VarHandle("N", kInt);
+  Placeholder A("A", kFloat, {N});
+  Tensor B =
+      Compute("B", {N}, [&](const VarHandle& i) { return sigmoid(A.load(i)); });
+  // NNC uses sleef for vectorizing sigmoid, which comes in an 8-wide flavor
+  // (Sleef_expf8).
+  constexpr int kSleefWidth = 8;
+  wrap = wrapTECompute(wrap, A, B, N, kSleefWidth);
+  updateNNCCache(aten::sigmoid, wrap);
+  return wrap;
+}
+
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/runtime/static/te_wrapper.h b/torch/csrc/jit/runtime/static/te_wrapper.h
new file mode 100644
index 0000000000..776602dc3e
--- /dev/null
+++ b/torch/csrc/jit/runtime/static/te_wrapper.h
@@ -0,0 +1,33 @@
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+
+namespace torch {
+namespace jit {
+
+class TEWrapper {
+ public:
+  TEWrapper() = default;
+  void call(const std::vector<void*>& args);
+  bool supports(const at::Tensor& t);
+#ifdef TORCH_ENABLE_LLVM
+  void update(std::unique_ptr<tensorexpr::LLVMCodeGen>&& cg_);
+#endif
+
+ private:
+#ifdef TORCH_ENABLE_LLVM
+  std::unique_ptr<tensorexpr::LLVMCodeGen> cg;
+#endif
+};
+
+std::shared_ptr<TEWrapper> createLogit(c10::optional<float> clamp);
+std::shared_ptr<TEWrapper> createRelu();
+std::shared_ptr<TEWrapper> createTanh();
+std::shared_ptr<TEWrapper> createSigmoid();
+
+} // namespace jit
+} // namespace torch
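
Usage sketch (not part of the patch): the operator functors that remain in
ops.cpp consume the relocated API through the new te_wrapper.h header. The
code below is a simplified illustration modeled on the existing aten::relu
functor in ops.cpp; helpers such as create_empty_from and fastResizeToZero
are ops.cpp internals as I recall them, and the exact fallback kernel and
resizing logic differ per op.

    // Illustrative only: how a Static Runtime op functor uses TEWrapper.
    REGISTER_OPERATOR_FUNCTOR(aten::relu, aten_relu, [](Node* n) -> SROperator {
      if (!n->matches(torch::schema("aten::relu(Tensor self) -> Tensor"))) {
        LogAndDumpSchema(n);
        return nullptr;
      }
      auto te = createRelu(); // cached NNC kernel built in te_wrapper.cpp
      return [te](ProcessedNode* p_node) {
        const auto& in0_t = p_node->Input(0).toTensor();
        if (p_node->Output(0).isNone()) {
          p_node->Output(0) = create_empty_from(in0_t);
        }
        auto& out_t = p_node->Output(0).toTensor();
        if (!te->supports(in0_t)) {
          // Non-contiguous or non-float input: fall back to the ATen kernel.
          fastResizeToZero(out_t);
          at::cpu::threshold_out(out_t, in0_t, 0, 0);
        } else {
          at::native::resize_(out_t, in0_t.sizes(), c10::nullopt);
          int64_t nn = in0_t.numel();
          // Argument order (output, input, dim) matches wrapTECompute above.
          te->call({out_t.data_ptr(), in0_t.data_ptr(), &nn});
        }
      };
    });

Because createRelu/createTanh/createSigmoid go through the NNC cache, the
LLVM compilation cost is paid once per op kind rather than once per node.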