_(prim, MMBatchSide) \
_(prim, min) \
_(prim, max) \
+ _(aten, _ncf_unsqueeze) \
_(aten, warn) \
_(aten, floordiv) \
_(aten, __round_to_zero_floordiv)\
return t.accessor<scalar_t, 1>();
}
+// Transforms a biased variance into an inverse standard deviation:
+// invstd = 1 / sqrt(var + eps). Returns 0 when both var and eps are zero,
+// so a constant input channel does not produce a division by zero / inf.
+template<typename T>
+struct InvStd {
+ T operator()(T var, double epsilon) const {
+ T invstd = 0;
+ if (var != static_cast<T>(0) || epsilon != static_cast<T>(0)) {
+ invstd = static_cast<T>(1) / std::sqrt(var + epsilon);
+ }
+ return invstd;
+ }
+};
+
+// Identity transform: passes the (biased) variance through unchanged and
+// ignores epsilon. Used by batch_norm_update_stats, which reports raw
+// variance rather than an inverse standard deviation.
+template<typename T>
+struct Var {
+ T operator()(T var, double epsilon) const {
+ return var;
+ }
+};
template<typename scalar_t>
-std::tuple<Tensor,Tensor,Tensor> batch_norm_cpu_template(const Tensor& input, const Tensor& weight, const Tensor& bias,
- const Tensor& running_mean, const Tensor& running_var, bool train, double momentum, double eps) {
+// Applies the per-channel batch-norm transform o = ((i - mean) * invstd) * w + b
+// over an (N, C, *) input. In training mode mean/invstd come from the
+// precomputed save_mean/save_invstd tensors; in eval mode they are derived
+// from running_mean/running_var and eps on the fly. Statistics computation
+// itself lives in batch_norm_cpu_update_stats_template.
+std::tuple<Tensor,Tensor,Tensor> batch_norm_cpu_transform_input_template(
+ const Tensor& input, const Tensor& weight, const Tensor& bias,
+ const Tensor& save_mean /* optional */, const Tensor& save_invstd /* optional */,
+ const Tensor& running_mean /* optional */, const Tensor& running_var /* optional */,
+ bool train, double eps) {
- using accscalar_t = at::acc_type<scalar_t, false>;
Tensor output = at::empty_like(input);
int64_t n_input = input.size(1);
- int64_t n = input.numel() / n_input;
- Tensor save_mean;
- Tensor save_invstd;
- const int64_t zero = 0;
- if (train) {
- save_mean = at::empty({n_input}, input.options());
- save_invstd = at::empty({n_input}, input.options());
- }
// Conditional accessors tolerate undefined (optional) tensors.
auto save_mean_a = conditional_accessor_1d<scalar_t>(save_mean);
auto save_invstd_a = conditional_accessor_1d<scalar_t>(save_invstd);
// NOTE(review): running_mean_a is used in the eval branch below, but its
// accessor declaration is not visible in this hunk -- confirm a
// `conditional_accessor_1d<scalar_t>(running_mean)` line exists alongside
// running_var_a.
auto running_var_a = conditional_accessor_1d<scalar_t>(running_var);
// Channels are independent, so parallelize over the channel dimension.
parallel_for(0, n_input, 1, [&](int64_t b_begin, int64_t b_end) {
- for (int64_t f = b_begin; f < b_end; ++f) {
- Tensor in = input.select(1, f);
- Tensor out = output.select(1, f);
+ for (int64_t f = b_begin; f < b_end; ++f) {
+ Tensor in = input.select(1, f);
+ Tensor out = output.select(1, f);
+
+ scalar_t mean, invstd;
+ if (train) {
+ mean = save_mean_a[f];
+ invstd = save_invstd_a[f];
+ } else {
+ mean = running_mean_a[f];
+ invstd = 1 / std::sqrt(running_var_a[f] + eps);
+ }
- scalar_t mean, invstd;
+ // compute output
+ // Undefined weight/bias degrade to the identity affine transform (w=1, b=0).
+ scalar_t w = weight.defined() ? weight.data<scalar_t>()[f * weight.stride(0)] : 1;
+ scalar_t b = bias.defined() ? bias.data<scalar_t>()[f * bias.stride(0)] : 0;
- if (train) {
- // compute mean per input
- accscalar_t sum = 0;
- CPU_tensor_apply1<scalar_t>(in, [&] (const scalar_t& i) {
- sum += i;
- });
-
- mean = (scalar_t) (sum / n);
- save_mean_a[f] = mean;
-
- // compute variance per input
- sum = 0;
- CPU_tensor_apply1<scalar_t>(in, [&] (const scalar_t& i) {
- sum += (i - mean) * (i - mean);
- });
-
- if (sum == 0 && eps == 0.0) {
- invstd = 0;
- } else {
- invstd = (scalar_t) (1 / std::sqrt(sum/n + eps));
- }
- save_invstd_a[f] = invstd;
+ CPU_tensor_apply2<scalar_t,scalar_t>(out, in, [&](scalar_t& o, const scalar_t& i) {
+ o = ((i - mean) * invstd) * w + b;
+ });
+ }
+ });
+ // save_mean / save_invstd are returned as passed in (undefined in eval mode).
+ return std::make_tuple(output, save_mean, save_invstd);
+}
- // update running averages
- if (running_mean.defined()) {
- running_mean_a[f] = momentum * mean + (1 - momentum) * running_mean_a[f];
- }
- if (running_var.defined()) {
- accscalar_t unbiased_var = sum / (n - 1);
- running_var_a[f] = momentum * unbiased_var + (1 - momentum) * running_var_a[f];
- }
- } else {
- mean = running_mean_a[f];
- invstd = 1 / std::sqrt(running_var_a[f] + eps);
- }
+// Computes per-channel mean and VarTransform(biased_var, eps) over an
+// (N, C, *) input, and folds the batch statistics into running_mean /
+// running_var (when defined) using the given momentum.
+// VarTransform = InvStd -> second result is 1/sqrt(var + eps) (training BN);
+// VarTransform = Var    -> second result is the raw biased variance.
+template<typename scalar_t, template<typename T> class VarTransform>
+std::tuple<Tensor,Tensor> batch_norm_cpu_update_stats_template(
+ const Tensor& input, const Tensor& running_mean, const Tensor& running_var,
+ double momentum, double eps) {
- // compute output
- scalar_t w = weight.defined() ? weight.data<scalar_t>()[f * weight.stride(0)] : 1;
- scalar_t b = bias.defined() ? bias.data<scalar_t>()[f * bias.stride(0)] : 0;
+ // Accumulate in a wider type to reduce rounding error for low-precision inputs.
+ using accscalar_t = at::acc_type<scalar_t, false>;
- CPU_tensor_apply2<scalar_t,scalar_t>(out, in, [&](scalar_t& o, const scalar_t& i) {
- o = ((i - mean) * invstd) * w + b;
- });
+ int64_t n_input = input.size(1);
+ // n = number of elements contributing to each channel's statistics.
+ int64_t n = input.numel() / n_input;
+
+ Tensor save_mean = at::empty({n_input}, input.options());
+ Tensor save_var_transform = at::empty({n_input}, input.options());
+ auto save_mean_a = save_mean.accessor<scalar_t, 1>();
+ auto save_var_transform_a = save_var_transform.accessor<scalar_t, 1>();
+
+ auto running_mean_a = conditional_accessor_1d<scalar_t>(running_mean);
+ auto running_var_a = conditional_accessor_1d<scalar_t>(running_var);
+
+ // Channels are independent, so parallelize over the channel dimension.
+ parallel_for(0, n_input, 1, [&](int64_t b_begin, int64_t b_end) {
+ for (int64_t f = b_begin; f < b_end; ++f) {
+ Tensor in = input.select(1, f);
+
+ // compute mean per input
+ accscalar_t sum = 0;
+ CPU_tensor_apply1<scalar_t>(in, [&] (const scalar_t& i) {
+ sum += i;
+ });
+ scalar_t mean = sum / n;
+ save_mean_a[f] = mean;
+
+ // compute variance per input
+ accscalar_t var_sum = 0;
+ CPU_tensor_apply1<scalar_t>(in, [&] (const scalar_t& i) {
+ var_sum += (i - mean) * (i - mean);
+ });
+ // Saved statistic uses the biased estimator (divide by n).
+ save_var_transform_a[f] = VarTransform<accscalar_t>{}(var_sum / n, eps);
+
+ // update running averages
+ if (running_mean.defined()) {
+ running_mean_a[f] = momentum * mean + (1 - momentum) * running_mean_a[f];
 }
- });
- return std::make_tuple(output, save_mean, save_invstd);
+ if (running_var.defined()) {
+ // Running variance uses the unbiased estimator (divide by n - 1).
+ accscalar_t unbiased_var = var_sum / (n - 1);
+ running_var_a[f] = momentum * unbiased_var + (1 - momentum) * running_var_a[f];
+ }
+ }
+ });
+ return std::make_tuple(save_mean, save_var_transform);
}
}
}
+// Native CPU entry for aten::batch_norm_update_stats: returns per-channel
+// (mean, biased variance) and updates running stats in place when defined.
+// eps is passed as 0 because the Var transform ignores it.
+std::tuple<Tensor, Tensor> batch_norm_update_stats_cpu(
+ const Tensor& self, const Tensor& running_mean, const Tensor& running_var, double momentum) {
+ return AT_DISPATCH_FLOATING_TYPES(self.type(), "batch_norm_update_stats", [&] {
+ return batch_norm_cpu_update_stats_template<scalar_t, Var>(self, running_mean, running_var, momentum, 0);
+ });
+}
+
// CPU batch norm entry point. Eval mode only transforms the input using the
// running statistics; training mode first computes batch statistics (updating
// the running stats with `momentum`), then transforms the input with the
// freshly computed mean/invstd.
std::tuple<Tensor, Tensor, Tensor> batch_norm_cpu(const Tensor& self, const Tensor& weight, const Tensor& bias,
 const Tensor& running_mean, const Tensor& running_var,
 bool train, double momentum, double eps) {
 return AT_DISPATCH_FLOATING_TYPES(self.type(), "batch_norm", [&] {
- return batch_norm_cpu_template<scalar_t>(self, weight, bias, running_mean, running_var, train, momentum, eps);
+ if (!train) {
+ // Undefined save_mean/save_invstd: eval mode derives them from running stats.
+ return batch_norm_cpu_transform_input_template<scalar_t>(self, weight, bias, {}, {}, running_mean, running_var, train, eps);
+ } else {
+ auto save_stats = batch_norm_cpu_update_stats_template<scalar_t, InvStd>(self, running_mean, running_var, momentum, eps);
+ return batch_norm_cpu_transform_input_template<scalar_t>(self, weight, bias, std::get<0>(save_stats), std::get<1>(save_stats), running_mean, running_var, train, eps);
+ }
 });
}
});
}
+// Native CUDA entry for aten::batch_norm_update_stats: returns per-channel
+// (mean, biased variance) and updates running stats in place when defined.
+std::tuple<Tensor, Tensor> batch_norm_update_stats_cuda(
+ const Tensor& self, const Tensor& running_mean, const Tensor& running_var, double momentum) {
+ // Dispatch label fixed from "batch_norm_backward" (copy-paste slip) so error
+ // messages name the right op, matching the CPU entry point.
+ return AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.type(), "batch_norm_update_stats", [&] {
+ // Prefer 32-bit indexing when the tensor fits; fall back to 64-bit.
+ if (cuda::detail::canUse32BitIndexMath(self)) {
+ return batch_norm_update_stats_cuda_template<scalar_t, int32_t>(self, running_mean, running_var, momentum);
+ } else {
+ return batch_norm_update_stats_cuda_template<scalar_t, int64_t>(self, running_mean, running_var, momentum);
+ }
+ });
+}
+
} } // namespace at::native
}
}
+// Device-side counterpart of the CPU InvStd functor: maps a biased variance
+// to 1/sqrt(var + eps), returning 0 when both var and eps are zero so a
+// constant channel does not produce inf.
+template<typename T>
+struct InvStd {
+ __device__ __forceinline__ T operator()(T var, double epsilon) const {
+ T invstd = 0;
+ if (var != static_cast<T>(0) || epsilon != static_cast<T>(0)) {
+ invstd = static_cast<T>(1) / device_sqrt(var + epsilon);
+ }
+ return invstd;
+ }
+};
-template <typename scalar_t, typename accscalar_t, typename index_t>
+// Device-side identity transform: returns the biased variance unchanged;
+// epsilon is ignored. Selected by batch_norm_update_stats_cuda_template.
+template<typename T>
+struct Var {
+ __device__ __forceinline__ T operator()(T var, double epsilon) const {
+ return var;
+ }
+};
+
+template <template<typename T> class VarTransform, typename scalar_t, typename accscalar_t, typename index_t>
__global__ void batch_norm_collect_statistics_kernel(
const PackedTensorAccessor<scalar_t, 3, RestrictPtrTraits, index_t> input,
const accscalar_t epsilon,
PackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> running_mean,
PackedTensorAccessor<scalar_t, 1, RestrictPtrTraits, index_t> running_var,
PackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t> save_mean,
- PackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t> save_invstd) {
+ PackedTensorAccessor<accscalar_t, 1, RestrictPtrTraits, index_t> save_transformed_var) {
__shared__ int shared_n[2 * 2 * WARP_SIZE + WARP_SIZE];
// this writes each warps item into shared memory
// there are at most WARP_SIZE items left because
- // there are at most WARP_SIZE**2 threads at the beginning
+ // there are at most WARP_SIZE**2 threads at the beginning
__syncthreads();
if (tid % WARP_SIZE == 0) {
shared_n[tid / WARP_SIZE] = n;
// Save the mean, variance, and moving averages
if (tid == 0) {
- accscalar_t invstd = 0;
- if (var_n != static_cast<accscalar_t>(0) || epsilon != static_cast<accscalar_t>(0)) {
- invstd = static_cast<accscalar_t>(1) / device_sqrt(var_n / N + epsilon);
- }
save_mean[plane] = avg;
- save_invstd[plane] = invstd;
+ save_transformed_var[plane] = VarTransform<accscalar_t>{}(var_n / N, epsilon);
if (running_mean.data() != NULL) {
running_mean[plane] = static_cast<scalar_t>((1 - momentum) * running_mean[plane] + momentum * avg);
}
dim3 blocks(input.size(1));
tf = getNumThreads(input.size(2));
dim3 threads(tf, std::max<int>(1, MAX_BLOCK_SIZE/tf));
- batch_norm_collect_statistics_kernel<scalar_t, accscalar_t, index_t> <<<blocks, threads, 0, stream>>>
+ batch_norm_collect_statistics_kernel<InvStd, scalar_t, accscalar_t, index_t> <<<blocks, threads, 0, stream>>>
(input, epsilon, momentum, running_mean, running_var, save_mean, save_invstd);
batch_norm_transform_input_kernel<scalar_t, accscalar_t, true, index_t> <<<blocks_trans, threads_trans, 0, stream>>>
(input, output, save_mean, save_invstd, weight, bias, epsilon);
return std::make_tuple(grad_input_, grad_weight_, grad_bias_);
}
+// Launches the collect-statistics kernel with the identity (Var) transform:
+// returns per-channel (mean, biased variance) and updates running_mean_ /
+// running_var_ in place when defined.
+template<typename scalar_t, typename index_t>
+std::tuple<Tensor, Tensor> batch_norm_update_stats_cuda_template(
+ const Tensor& input_, const Tensor& running_mean_, const Tensor& running_var_, double momentum) {
+
+ using accscalar_t = at::acc_type<scalar_t, true>;
+ int64_t n_channels = input_.size(1);
+ auto input_reshaped = input_.reshape({input_.size(0), input_.size(1), -1}); // internally we merge the feature dimensions
+
+ // Half inputs store their statistics in float for accuracy.
+ auto input_options = input_.options();
+ if (input_.type().scalarType() == at::ScalarType::Half) {
+ input_options = input_options.dtype(ScalarType::Float);
+ }
+ Tensor save_mean_ = at::empty({n_channels}, input_options);
+ Tensor save_var_ = at::empty({n_channels}, input_options);
+
+ auto input = input_reshaped.packed_accessor<scalar_t, 3, RestrictPtrTraits, index_t>();
+ // Dummy accessors stand in for undefined (optional) running stats.
+ auto running_mean = packed_accessor_or_dummy<scalar_t, 1, RestrictPtrTraits, index_t>(running_mean_);
+ auto running_var = packed_accessor_or_dummy<scalar_t, 1, RestrictPtrTraits, index_t>(running_var_);
+ auto save_mean = save_mean_.packed_accessor<accscalar_t, 1, RestrictPtrTraits, index_t>();
+ auto save_var = save_var_.packed_accessor<accscalar_t, 1, RestrictPtrTraits, index_t>();
+ auto stream = at::cuda::getCurrentCUDAStream();
+
+ // for the reduction, we cannot use blocks for the batch dim, but if we have few threads in
+ // the feature dimension, we'll use some threads for blocks
+ dim3 blocks(input.size(1));
+ int tf = getNumThreads(input.size(2));
+ dim3 threads(tf, std::max<int>(1, MAX_BLOCK_SIZE/tf));
+ // NB: epsilon is unused by the Var transform, so we set it to 0
+ batch_norm_collect_statistics_kernel<Var, scalar_t, accscalar_t, index_t> <<<blocks, threads, 0, stream>>>
+ (input, 0., momentum, running_mean, running_var, save_mean, save_var);
+ THCudaCheck(cudaGetLastError());
+ return std::make_tuple(save_mean_, save_var_);
+
+}
+
} } // namespace at::native
CPU: batch_norm_backward_cpu
CUDA: batch_norm_backward_cuda
+- func: batch_norm_update_stats(Tensor input, Tensor? running_mean, Tensor? running_var, double momentum) -> (Tensor, Tensor)
+ dispatch:
+ CPU: batch_norm_update_stats_cpu
+ CUDA: batch_norm_update_stats_cuda
+
- func: ones(IntList size, TensorOptions options={}) -> Tensor
- func: ones_out(Tensor result, IntList size) -> Tensor
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+ def test_fuse_batch_norm(self):
+ # End-to-end check that the graph fuser decomposes aten::batch_norm into
+ # batch_norm_update_stats + pointwise ops, that the decomposed model still
+ # matches the unoptimized one numerically, and that the pointwise part
+ # lands in a single fusion group.
+
+ class ResLike(torch.jit.ScriptModule):
+ def __init__(self, optimize=True):
+ super(ResLike, self).__init__(optimize)
+ self.bn = nn.BatchNorm2d(16)
+
+ @torch.jit.script_method
+ def forward(self, x, y):
+ return y + torch.relu(self.bn(x))
+
+ model = ResLike().cuda()
+ model_noopt = ResLike(optimize=False).cuda()
+ # Identical parameters/buffers so outputs are comparable.
+ model_noopt.load_state_dict(model.state_dict())
+ x = torch.randn(2, 16, 8, 8, device='cuda')
+ y = torch.randn(2, 16, 8, 8, device='cuda')
+ # FIXME: We need differentiation for CNNs for this optimization to trigger
+ with torch.no_grad():
+ out = model(x, y)
+ graph = model.graph_for(x, y)
+ rep = str(graph)
+
+ out_noopt = model_noopt(x, y)
+ rep_noopt = str(model_noopt.graph_for(x, y))
+ self.assertEqual(out, out_noopt, prec=3e-5)
+
+ # Check that batch_norm has really been decomposed
+ self.assertIn('aten::batch_norm_update_stats', rep)
+ self.assertNotIn('aten::batch_norm(', rep)
+ self.assertIn('aten::batch_norm(', rep_noopt)
+
+ # Make sure the fusion group is big, and contains aten::sqrt, which could
+ # originate only from decomposing batch_norm in this case
+ fusion_groups = [node for node in graph.nodes() if node.kind() == 'prim::FusionGroup']
+ self.assertEqual(len(fusion_groups), 1)
+ fused_graph = fusion_groups[0].g('Subgraph')
+ self.assertTrue(any(node.kind() == 'aten::sqrt' for node in fused_graph.nodes()))
+
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+ def test_threshold(self):
+ # threshold is now a fusable pointwise op (see the codegen entry for
+ # aten::threshold); the whole expression should compile into one fusion
+ # group and still match eager execution.
+ def f(x):
+ return torch.threshold(x, 0, -10) + x + x + x
+
+ x = torch.tensor([-1, -0.5, 0, 1, 2, 3], device='cuda')
+ scripted = torch.jit.script(f)
+
+ self.assertEqual(f(x), scripted(x))
+ self.assertAllFused(scripted.graph_for(x))
+
+ @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+ @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
@unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
@skipIfRocm
@enable_cpu_fuser
{aten::abs, "fabs(${0})"},
{aten::sigmoid, "1.f / (1.f + expf(-${0}))"},
{aten::relu, "${0} < 0 ? 0.f : ${0} "},
+ {aten::threshold,
+ "${0} <= ${1} ? static_cast<decltype(${0})>(${2}) : ${0} "},
{aten::log, "logf(${0})"},
{aten::log10, "log10f(${0})"},
{aten::log1p, "log1pf(${0})"},
#include <ATen/ExpandUtils.h>
#include <torch/csrc/jit/assertions.h>
#include <torch/csrc/jit/autodiff.h>
+#include <torch/csrc/jit/custom_operator.h>
#include <torch/csrc/jit/fuser/interface.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/passes/alias_analysis.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
+#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
+#include <torch/csrc/jit/script/compiler.h>
#include <torch/csrc/jit/symbolic_variable.h>
#include <unordered_map>
//"aten::rand_like(Tensor self) -> Tensor",
"aten::reciprocal(Tensor self) -> Tensor",
"aten::relu(Tensor self) -> Tensor",
+ "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor",
"aten::remainder(Tensor self, Tensor other) -> Tensor",
"aten::round(Tensor self) -> Tensor",
"aten::rsqrt(Tensor self) -> Tensor",
return true;
}
+// Registers aten::_ncf_unsqueeze: reshapes a 1-D tensor of C elements into
+// shape (1, C, 1, ..., 1) with `ndim` dimensions, so per-channel statistics
+// broadcast against an NCF (batch, channel, feature...) input.
+RegisterOperators reg_bn_unsqueeze({Operator(
+ "aten::_ncf_unsqueeze(Tensor self, int ndim) -> Tensor",
+ [](const Node* node) {
+ return [](Stack& stack) {
+ const int64_t ndim = pop(stack).toInt();
+ auto self = pop(stack).toTensor();
+ c10::SmallVector<int64_t, 8> sizes(ndim, 1);
+ JIT_ASSERT(self.dim() == 1);
+ // NOTE(review): sizes.at(1) throws for ndim < 2; callers pass the dim()
+ // of an NCF tensor, so ndim >= 2 is assumed here -- confirm.
+ sizes.at(1) = self.size(0);
+ push(stack, self.reshape(sizes));
+ return 0;
+ };
+ })});
+
+// Yes, no, or no value if we can't tell
+// true    -> value is statically known to be a defined Tensor;
+// false   -> value is statically None/Undefined;
+// nullopt -> definedness cannot be determined from the graph alone.
+c10::optional<bool> isDefined(Value* tensor) {
+ if (tensor->type()->isSubtypeOf(DynamicType::get())) {
+ return true;
+ }
+ if (tensor->node()->kind() == prim::None ||
+ tensor->node()->kind() == prim::Undefined) {
+ return false;
+ }
+ return {};
+}
+
+// A batch_norm node is decomposable (and thus fusable) only when it matches
+// the canonical aten::batch_norm schema and the definedness of weight/bias
+// can be decided statically (decomposeBatchNorm relies on that).
+bool isFusableBatchNorm(Node* batch_norm) {
+ if (!batch_norm->matches(
+ "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor")) {
+ return false;
+ }
+ // If we can't determine if weight and bias is defined statically there's
+ // really no point in decomposing batch norm into simpler ops, since it won't
+ // get fused into a single kernel.
+ return isDefined(batch_norm->namedInput(attr::weight)).has_value() &&
+ isDefined(batch_norm->namedInput(attr::bias)).has_value();
+}
+
Value* broadcastSizes(at::ArrayRef<Value*> sizes) {
JIT_ASSERT(!sizes.empty());
Graph* graph = sizes[0]->owningGraph();
struct GraphFuser {
Block* block_;
+ c10::optional<AliasDb> aliasDb_;
std::shared_ptr<Graph> graph_;
GraphFuser(Block* block, std::shared_ptr<Graph> graph)
}
bool isFusable(Node* node) {
+ return isFusableMap(node) || isFusableBatchNorm(node);
+ }
+
+ bool isFusableMap(Node* node) {
// We don't want to bother with cross-block node movements, as they
// are not necessarily correct.
if (node->owningBlock() != block_)
// cannot be fused because it is not a simple map, can be put in a fusion
// group as long as no items in the group read the output of concat
bool isFusableAsExitNode(Node* node) {
- return isFusable(node) || isFusableOnlyAsExitNode(node);
+ return isFusableMap(node) || isFusableOnlyAsExitNode(node);
}
bool isFusableOnlyAsExitNode(Node* node) {
return *n->g(attr::Subgraph);
}
+ // Replaces an aten::batch_norm node with an inlined TorchScript graph that
+ // computes the normalization explicitly (stats via batch_norm_update_stats
+ // in training mode, running stats in eval mode), then appends the affine
+ // weight/bias ops only when they are statically known to be defined. The
+ // resulting pointwise ops become candidates for fusion.
+ void decomposeBatchNorm(Node* batch_norm) {
+ static std::shared_ptr<Graph> bn_graph;
+ static std::once_flag flag;
+ // BUGFIX: invstd must be 1 / sqrt(var + eps) -- eps belongs INSIDE the
+ // sqrt. The previous "1 / (eps + torch.sqrt(norm_var))" computed a
+ // different quantity and disagreed with the native InvStd functors.
+ // torch.sqrt (not rsqrt) is kept so the decomposition still emits
+ // aten::sqrt, which test_fuse_batch_norm asserts on.
+ std::call_once(
+ flag,
+ [](std::shared_ptr<Graph>* graph_ptr) {
+ static const char* source = R"SCRIPT(
+ def batch_norm(input : Tensor, running_mean : Optional[Tensor], running_var : Optional[Tensor], training : bool, momentum : float, eps : float) -> Tensor:
+     if training:
+         norm_mean, norm_var = torch.batch_norm_update_stats(input, running_mean, running_var, momentum)
+     else:
+         norm_mean = torch._unwrap_optional(running_mean)
+         norm_var = torch._unwrap_optional(running_var)
+     norm_mean = torch._ncf_unsqueeze(norm_mean, input.dim())
+     norm_var = torch._ncf_unsqueeze(norm_var, input.dim())
+     norm_invstd = 1 / torch.sqrt(norm_var + eps)
+     return ((input - norm_mean) * norm_invstd)
+ )SCRIPT";
+ auto module = std::make_shared<script::Module>();
+ defineMethodsInModule(
+ module, source, script::nativeResolver, /*self=*/nullptr);
+ *graph_ptr = module->get_method("batch_norm").graph();
+ },
+ &bn_graph);
+
+ JIT_ASSERT(isFusableBatchNorm(batch_norm));
+ WithInsertPoint insert_guard{batch_norm};
+ Value* input = batch_norm->namedInput(attr::input);
+ Value* input_dim = graph_->insert(aten::dim, {input});
+ std::vector<Value*> inputs{input,
+ batch_norm->namedInput(attr::running_mean),
+ batch_norm->namedInput(attr::running_var),
+ batch_norm->namedInput(attr::training),
+ batch_norm->namedInput(attr::momentum),
+ batch_norm->namedInput(attr::eps)};
+ // Inline the normalization graph at the batch_norm site.
+ Value* new_output =
+ SubgraphUtils::inlineGraph(bn_graph, inputs, batch_norm).at(0);
+ auto weight = batch_norm->namedInput(attr::weight);
+ auto bias = batch_norm->namedInput(attr::bias);
+ // isFusableBatchNorm guarantees definedness is statically known here.
+ if (isDefined(weight).value()) {
+ Value* expanded_weight =
+ graph_->insert(aten::_ncf_unsqueeze, {weight, input_dim});
+ new_output = graph_->insert(aten::mul, {new_output, expanded_weight});
+ }
+ if (isDefined(bias).value()) {
+ Value* expanded_bias =
+ graph_->insert(aten::_ncf_unsqueeze, {bias, input_dim});
+ new_output = graph_->insert(aten::add, {new_output, expanded_bias});
+ }
+ batch_norm->output()->replaceAllUsesWith(new_output);
+ batch_norm->destroy();
+ }
+
void mergeFusionGroups(Node* consumer_group, Node* producer_group) {
// Now we have two fusion groups!
// Revert the fusion - place all inner nodes of producer back in the outer
*insertion_point = n;
}
- at::optional<Node*> tryFuse(
- Node* consumer,
- Value* producer,
- const AliasDb& aliasDb) {
+ at::optional<Node*> tryFuse(Node* consumer, Value* producer) {
// this handles cases where producer can be moved _into_ the fusion group of
// consumer.
// TODO: extend to fusion of consumer into _producer's_ fusion blob
// consumer. Fusion will rewrite those later uses to use the version of
// producer generated by the fused blob. In this case, producer becomes
// an output of the fusion group.
- producer->node()->moveBeforeTopologicallyValid(real_consumer, aliasDb);
+ producer->node()->moveBeforeTopologicallyValid(
+ real_consumer, aliasDb_.value());
if (!shouldFuse) {
return at::nullopt;
} else if (consumer->kind() != prim::FusionGroup) {
group = createSingletonFusionGroup(consumer);
}
+ if (producer->node()->matches(
+ "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor")) {
+ // We don't do any fusions in here, but simply decompose the batch norm
+ // into a kernel that computes the stats + pointwise ops which will be
+ // considered in this fusion next.
+ decomposeBatchNorm(producer->node());
+ return group;
+ }
if (producer->node()->kind() == prim::FusionGroup) {
mergeFusionGroups(group, producer->node());
return group;
chunk->inputs().begin(),
chunk->inputs().end(),
[&](Value* producer_for_chunk) {
- return isFusable(producer_for_chunk->node()) &&
+ return isFusableMap(producer_for_chunk->node()) &&
allUsersAreThisConsumerOrCalcSizes(chunk, producer_for_chunk);
});
if (it == chunk->inputs().end()) {
}
// returns where to continue scanning, and whether any fusion was made
- std::pair<graph_node_list::iterator, bool> scanNode(
- Node* consumer,
- const AliasDb& aliasDb) {
+ std::pair<graph_node_list::iterator, bool> scanNode(Node* consumer) {
if (isFusableAsExitNode(consumer)) {
auto consumer_inputs = consumer->kind() == aten::cat
? consumer->namedInput(attr::tensors)->node()->inputs()
// we scan this consumer again to perform the fusion
return std::make_pair(consumer->reverseIterator(), true);
}
- auto fusion_group = tryFuse(consumer, producer, aliasDb);
+ auto fusion_group = tryFuse(consumer, producer);
if (fusion_group) {
// after fusion, consumer moves into a FusionGroup, so inputs is no
// longer valid so we rescan the new FusionGroup for more fusions...
}
}
+ void refreshAliasDb() {
+ aliasDb_ = AliasAnalysis(graph_);
+ }
+
void run() {
// Run the pass until no changes are made.
// This is neccessary, because the algorithm can miss out on certain fusion
bool any_changed = true;
while (any_changed) {
any_changed = false;
- auto aliasDb = AliasAnalysis(graph_);
+ refreshAliasDb();
for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) {
bool changed;
- std::tie(it, changed) = scanNode(*it, aliasDb);
+ std::tie(it, changed) = scanNode(*it);
any_changed |= changed;
}
}
namespace jit {
namespace SubgraphUtils {
namespace {
-bool isSubgraphNodeKind(Symbol s) {
- return s == prim::DifferentiableGraph || s == prim::FusionGroup;
-}
-bool isSubgraphNodeKind(Node* n) {
- return isSubgraphNodeKind(n->kind());
+// A node can host merged nodes iff it carries a Subgraph attribute; this
+// attribute check replaces the old hard-coded node-kind whitelist.
+bool hasSubgraph(Node* n) {
+ return n->hasAttribute(attr::Subgraph);
}
// Combine the nodes in two subgraph together. The nodes will end up in
// `mergeTo`, and `mergeFrom` is destroyed.
void mergeSubgraph(Node* mergeTo, Node* mergeFrom) {
- const auto nodes = unmergeSubgraph(mergeFrom);
- for (auto it = nodes.rbegin(); it != nodes.rend(); ++it) {
- mergeNodeIntoSubgraph(*it, mergeTo);
+ // Remember mergeFrom's neighbors before inlining: inlineGraph (called via
+ // unmergeSubgraph) guarantees every inlined node lands strictly between
+ // nodeBeforeMergeFrom and the (now destroyed) mergeFrom node.
+ Node* nodeBeforeMergeFrom = mergeFrom->prev();
+ Node* nodeAfterMergeFrom = mergeFrom->next();
+ unmergeSubgraph(mergeFrom);
+ // Walk the inlined span in reverse and fold each node into mergeTo.
+ // (Removed an unused `std::vector<Node*> nodes;` left over from the old
+ // clone-based implementation -- it was never read or written.)
+ const auto end_it = nodeBeforeMergeFrom->reverseIterator();
+ auto it = nodeAfterMergeFrom->reverseIterator();
+ ++it;
+ while (it != end_it) {
+ // NB: mergeNodeIntoSubgraph destroys node, hence the complications
+ Node* node = *it;
+ ++it;
+ mergeNodeIntoSubgraph(node, mergeTo);
 }
}
} // namespace
// Fetches the Subgraph attribute; valid for any node with hasSubgraph()
// (the former node-kind assertion was removed along with the whitelist).
std::shared_ptr<Graph> getSubgraph(Node* n) {
- JIT_ASSERT(isSubgraphNodeKind(n));
 return n->g(attr::Subgraph);
}
-std::vector<Node*> unmergeSubgraph(Node* subgraphNode) {
+// Inlines the nodes of subgraphNode's subgraph into the outer graph directly
+// before subgraphNode, rewires all users of its outputs to the inlined
+// values, and destroys subgraphNode.
+void unmergeSubgraph(Node* subgraphNode) {
 JIT_ASSERT(subgraphNode->kind() == prim::DifferentiableGraph);
- auto outerGraph = subgraphNode->owningGraph();
-
- std::vector<Node*> temporary_nodes;
- auto subgraph = getSubgraph(subgraphNode);
-
- // Initialize a map of inner graph values to outer graph values
- std::unordered_map<const Value*, Value*> innerToOuter;
- const auto innerInputs = subgraph->inputs();
- const auto outerInputs = subgraphNode->inputs();
- for (size_t i = 0; i < innerInputs.size(); ++i) {
- innerToOuter[innerInputs[i]] = outerInputs[i];
- }
-
- // Clone all nodes
- for (auto inner : subgraph->nodes()) {
- Node* outer = outerGraph->createClone(
- inner, [&](Value* k) -> Value* { return innerToOuter.at(k); });
- outer->insertBefore(subgraphNode);
- temporary_nodes.emplace_back(outer);
- const auto innerOutputs = inner->outputs();
- const auto outerOutputs = outer->outputs();
- for (size_t i = 0; i < innerOutputs.size(); ++i) {
- innerToOuter[innerOutputs[i]] = outerOutputs[i];
- }
- }
- // Replace uses of group outputs and destroy the group
- const auto subgraphOutputs = subgraph->outputs();
+ // Inline the graph, replace uses of node outputs and destroy the node
+ const auto subgraphOutputs = inlineGraph(
+ getSubgraph(subgraphNode), subgraphNode->inputs(), subgraphNode);
+ // The subgraph may expose more outputs than the node does -- presumably
+ // extra values captured internally; only the node's own outputs are rewired.
 JIT_ASSERT(subgraphOutputs.size() >= subgraphNode->outputs().size());
 for (size_t i = 0; i < subgraphNode->outputs().size(); ++i) {
- const auto outerOutput = innerToOuter.at(subgraphOutputs[i]);
- subgraphNode->outputs()[i]->replaceAllUsesWith(outerOutput);
+ subgraphNode->outputs()[i]->replaceAllUsesWith(subgraphOutputs[i]);
 }
 subgraphNode->destroy();
-
- return temporary_nodes;
}
void mergeNodeIntoSubgraph(Node* toMerge, Node* subgraphNode) {
- JIT_ASSERT(isSubgraphNodeKind(subgraphNode));
- if (isSubgraphNodeKind(toMerge)) {
+ JIT_ASSERT(hasSubgraph(subgraphNode));
+ if (hasSubgraph(toMerge)) {
return mergeSubgraph(subgraphNode, toMerge);
}
toMerge->destroy();
}
+// Invariant we depend on in mergeSubgraph: All inlined nodes are created
+// between the node preceding insertBefore and insertBefore.
+// Clones every node of `subgraph` into insertBefore's graph, substituting the
+// subgraph's inputs with `outerInputs`. Returns the outer values that
+// correspond to the subgraph's outputs, in order.
+std::vector<Value*> inlineGraph(
+ const std::shared_ptr<Graph>& subgraph,
+ at::ArrayRef<Value*> outerInputs,
+ Node* insertBefore) {
+ auto outerGraph = insertBefore->owningGraph();
+
+ // Initialize a map of inner graph values to outer graph values
+ std::unordered_map<const Value*, Value*> innerToOuter;
+ const auto innerInputs = subgraph->inputs();
+ JIT_ASSERT(outerInputs.size() == innerInputs.size());
+ for (size_t i = 0; i < innerInputs.size(); ++i) {
+ innerToOuter[innerInputs[i]] = outerInputs[i];
+ }
+
+ // Clone all nodes
+ // Iteration is in topological order, so every operand of `inner` is already
+ // present in innerToOuter when it is cloned.
+ for (auto inner : subgraph->nodes()) {
+ Node* outer = outerGraph->createClone(
+ inner, [&](Value* k) -> Value* { return innerToOuter.at(k); });
+ outer->insertBefore(insertBefore);
+ const auto innerOutputs = inner->outputs();
+ const auto outerOutputs = outer->outputs();
+ for (size_t i = 0; i < innerOutputs.size(); ++i) {
+ innerToOuter[innerOutputs[i]] = outerOutputs[i];
+ }
+ }
+
+ return fmap(subgraph->outputs(), [&](Value* output) {
+ return innerToOuter.at(output);
+ });
+}
+
+
// Wraps node n in a fresh subgraph node of the given kind. (The old
// node-kind whitelist assertion was dropped: any kind carrying a Subgraph
// attribute is acceptable now.)
Node* createSingletonSubgraph(Node* n, Symbol subgraphKind) {
- JIT_ASSERT(isSubgraphNodeKind(subgraphKind));
 auto graph = n->owningGraph();
 auto subgraph = graph->create(subgraphKind, 0);
 subgraph->g_(attr::Subgraph, std::make_shared<Graph>(graph->current_scope()));
 mergeNodeIntoSubgraph(n, subgraph);
 return subgraph;
}
+
} // namespace SubgraphUtils
} // namespace jit
} // namespace torch
// Move nodes from a subgraph node to the outer graph.
// `subgraphNode` is destroyed.
-std::vector<Node*> unmergeSubgraph(Node* subgraphNode);
+void unmergeSubgraph(Node* subgraphNode);
// Convenience function
std::shared_ptr<Graph> getSubgraph(Node* n);
+std::vector<Value*> inlineGraph(
+ const std::shared_ptr<Graph>& subgraph,
+ at::ArrayRef<Value*> outerInputs,
+ Node* insertBefore);
+
} // namespace SubgraphUtils
} // namespace jit
} // namespace torch